devdz committed
Commit 5eae308 · verified · 1 Parent(s): 9538b62

Upload 369 files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.

Files changed (50)
  1. .gitattributes +4 -0
  2. t5x-main/.github/workflows/build.yaml +39 -0
  3. t5x-main/CONTRIBUTING.md +1 -0
  4. t5x-main/LICENSE +202 -0
  5. t5x-main/README.md +525 -0
  6. t5x-main/docs/_static/t5x_theme.css +23 -0
  7. t5x-main/docs/_templates/autosummary/t5x_module.rst +23 -0
  8. t5x-main/docs/api_reference/index.rst +100 -0
  9. t5x-main/docs/api_reference/t5x.adafactor.rst +7 -0
  10. t5x-main/docs/api_reference/t5x.binary_search.rst +7 -0
  11. t5x-main/docs/api_reference/t5x.checkpoint_importer.rst +7 -0
  12. t5x-main/docs/api_reference/t5x.checkpoint_utils.rst +7 -0
  13. t5x-main/docs/api_reference/t5x.checkpoints.rst +7 -0
  14. t5x-main/docs/api_reference/t5x.config_utils.rst +7 -0
  15. t5x-main/docs/api_reference/t5x.decoding.rst +7 -0
  16. t5x-main/docs/api_reference/t5x.eval.rst +7 -0
  17. t5x-main/docs/api_reference/t5x.gin_utils.rst +7 -0
  18. t5x-main/docs/api_reference/t5x.infer.rst +7 -0
  19. t5x-main/docs/api_reference/t5x.interactive_model.rst +7 -0
  20. t5x-main/docs/api_reference/t5x.losses.rst +7 -0
  21. t5x-main/docs/api_reference/t5x.main.rst +7 -0
  22. t5x-main/docs/api_reference/t5x.metrics.rst +7 -0
  23. t5x-main/docs/api_reference/t5x.models.rst +7 -0
  24. t5x-main/docs/api_reference/t5x.optimizers.rst +7 -0
  25. t5x-main/docs/api_reference/t5x.partitioning.rst +7 -0
  26. t5x-main/docs/api_reference/t5x.state_utils.rst +7 -0
  27. t5x-main/docs/api_reference/t5x.test_utils.rst +7 -0
  28. t5x-main/docs/api_reference/t5x.train.rst +7 -0
  29. t5x-main/docs/api_reference/t5x.train_state.rst +7 -0
  30. t5x-main/docs/api_reference/t5x.trainer.rst +7 -0
  31. t5x-main/docs/api_reference/t5x.utils.rst +7 -0
  32. t5x-main/docs/conf.py +132 -0
  33. t5x-main/docs/conf_sphinx_patch.py +202 -0
  34. t5x-main/docs/contributions.md +64 -0
  35. t5x-main/docs/index.md +65 -0
  36. t5x-main/docs/index.rst +24 -0
  37. t5x-main/docs/models.md +318 -0
  38. t5x-main/docs/overview.md +2 -0
  39. t5x-main/docs/requirements.txt +8 -0
  40. t5x-main/docs/t5x.png +3 -0
  41. t5x-main/docs/tutorials.md +51 -0
  42. t5x-main/docs/usage/auxiliary.md +204 -0
  43. t5x-main/docs/usage/decoding.md +199 -0
  44. t5x-main/docs/usage/eval.md +226 -0
  45. t5x-main/docs/usage/finetune.md +286 -0
  46. t5x-main/docs/usage/gin.md +395 -0
  47. t5x-main/docs/usage/gpu-usage.md +87 -0
  48. t5x-main/docs/usage/index.rst +16 -0
  49. t5x-main/docs/usage/infer-files.md +217 -0
  50. t5x-main/docs/usage/infer-seqio.md +241 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ t5x-main/docs/t5x.png filter=lfs diff=lfs merge=lfs -text
+ t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 filter=lfs diff=lfs merge=lfs -text
+ t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.meta filter=lfs diff=lfs merge=lfs -text
+ t5x-main/t5x/testdata/test_t5_tiny.checkpoint_0 filter=lfs diff=lfs merge=lfs -text
t5x-main/.github/workflows/build.yaml ADDED
@@ -0,0 +1,39 @@
+ name: build
+
+ on: [push]
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python
+       uses: actions/setup-python@v4
+       with:
+         python-version: '3.10.x'
+         cache: 'pip'
+         cache-dependency-path: setup.py
+     - name: Install dependencies
+       run: |
+         pip install -e .[test] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+     - name: Test with pytest
+       run: |
+         pytest
+     # The below step just reports the success or failure of tests as a "commit status".
+     # This is needed for copybara integration.
+     - name: Report success or failure as github status
+       if: always()
+       shell: bash
+       run: |
+         status="${{ job.status }}"
+         lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
+         curl -sS --request POST \
+           --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
+           --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
+           --header 'content-type: application/json' \
+           --data '{
+             "state": "'$lowercase_status'",
+             "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
+             "description": "'$status'",
+             "context": "github-actions/build"
+           }'
t5x-main/CONTRIBUTING.md ADDED
@@ -0,0 +1 @@
+ External contributions are not accepted, sorry!
t5x-main/LICENSE ADDED
@@ -0,0 +1,202 @@
+
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
t5x-main/README.md ADDED
@@ -0,0 +1,525 @@
+ # T5X
+
+ *Go to [T5X ReadTheDocs Documentation Page](https://t5x.readthedocs.io/).*
+
+ T5X is a modular, composable, research-friendly framework for high-performance,
+ configurable, self-service training, evaluation, and inference of sequence
+ models (starting with language) at many scales.
+
+ It is essentially a new and improved implementation of the
+ [T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer)
+ (based on [Mesh TensorFlow](https://github.com/tensorflow/mesh)) in
+ [JAX](https://github.com/google/jax) and [Flax](https://github.com/google/flax).
+ To learn more, see the [T5X Paper](https://arxiv.org/abs/2203.17189).
+
+ Below is a quick start guide for training models with TPUs on Google Cloud. For
+ additional tutorials and background, see the [complete documentation](docs/index.md).
+
+ ## Quickstart (Recommended)
+
+ T5X can be run with [XManager](https://github.com/deepmind/xmanager) on
+ [Vertex AI](https://cloud.google.com/vertex-ai). Vertex AI is a platform for
+ training that creates TPU instances and runs code on the TPUs. Vertex AI will
+ also shut down the TPUs when the jobs terminate. This is significantly easier
+ than managing GCE VMs and TPU VM instances.
+
+ 1. Follow the prerequisites and directions to install [XManager](https://github.com/deepmind/xmanager).
+
+ 2. Request TPU quota as required. GCP projects come with 8 cores by default,
+    which is enough to run one training experiment on a single TPU host. If you
+    want to run multi-host training or run multiple trials in parallel, you will
+    need more quota. Navigate to [Quotas](https://console.cloud.google.com/quotas).
+
+    The quota you want is:
+
+    * Service: `Vertex AI API`
+    * Dimensions (location): `us-central1`
+    * If you want to run single-host experiments:
+      * `Custom model training TPU V2 cores per region`
+      * `Custom model training TPU V3 cores per region`
+    * If you want to run multi-host experiments:
+      * `Custom model training TPU V2 pod cores per region`
+      * `Custom model training TPU V3 pod cores per region`
+
+    TIP: You won't be able to run single-host experiments with multi-host quota
+    (i.e., you can't run `tpu_v2=8` using `TPU V2 pod`).
+
+ 3. Launch the xmanager script located at `t5x/scripts/xm_launch.py`.
+
+    As a running example, we use the WMT14 En-De translation, which is described
+    in more detail in the Examples section below.
+
+    ```sh
+    export GOOGLE_CLOUD_BUCKET_NAME=...
+    export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data
+    export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d)
+
+    # Pre-download dataset in multi-host experiments.
+    tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR
+
+    git clone https://github.com/google-research/t5x
+    cd ./t5x/
+
+    python3 ./t5x/scripts/xm_launch.py \
+      --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \
+      --model_dir=$MODEL_DIR \
+      --tfds_data_dir=$TFDS_DATA_DIR
+    ```
+
+ Check `gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/` for the output artifacts, which can
+ be read by TensorBoard.
+
+ ## GPU Usage
+
+ Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository for more details and usage instructions.
+
+ T5X can be run easily on GPUs, either in single-node or multi-node configurations with a SLURM+pyxis cluster. Further instructions are at [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and the associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA.
+
+ ## Installation
+
+ Note that all the commands in this document should be run in the commandline of
+ the TPU VM instance unless otherwise stated.
+
+ 1. Follow the
+    [instructions](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_the_google_cloud_sdk)
+    to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU
+    API.
+
+    **Note:** T5X also works with GPUs. Please follow the instructions in
+    [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md)
+    if you'd like to use the GPU version.
+
+ 2. Create a
+    [Cloud TPU VM instance](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)
+    following
+    [this instruction](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#create-vm).
+    We recommend that you develop your workflow on a single v3-8 TPU (i.e.,
+    `--accelerator-type=v3-8`) and scale up to pod slices once the pipeline is
+    ready. In this README, we focus on using a single v3-8 TPU. See
+    [here](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) to
+    learn more about TPU architectures.
+
+ 3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM.
+    You can install packages, run your code, etc. on the host machine. Once
+    the TPU instance is created, ssh into it with
+
+    ```sh
+    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}
+    ```
+
+    where `TPU_NAME` and `ZONE` are the name and the zone used in step 2.
+
+ 4. Install T5X and its dependencies.
+
+    ```sh
+    git clone --branch=main https://github.com/google-research/t5x
+    cd t5x
+
+    python3 -m pip install -e '.[tpu]' -f \
+      https://storage.googleapis.com/jax-releases/libtpu_releases.html
+    ```
+
+ 5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model
+    checkpoints. To create a GCS bucket, see these
+    [instructions](https://cloud.google.com/storage/docs/creating-buckets).
+
+ 6. (optional) If you prefer working with a Jupyter/Colab style environment,
+    you can set up a custom Colab runtime by following the steps in
+    [t5x/notebooks](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).
+
+ ## Example: English to German translation
+
+ As a running example, we use the WMT14 En-De translation. The raw dataset is
+ available in TensorFlow Datasets as
+ ["wmt_t2t_translate"](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate).
+
+ T5 casts a translation example such as
+
+ ```py
+ {'en': 'That is good.', 'de': 'Das ist gut.'}
+ ```
+
+ into the "text-to-text" form:
+
+ ```py
+ {'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}
+ ```
+
+ This formulation allows many different classes of language tasks to be expressed
+ in a uniform manner, so a single encoder-decoder architecture can handle them
+ without any task-specific parameters. For more detail, refer to the [T5 paper
+ (Raffel et al. 2019)][t5_paper].
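As an illustration, the cast above can be written as a small pure-Python function. This is only a sketch: the name `to_text_to_text` is ours, and in T5X the cast is actually performed by SeqIO preprocessors operating on `tf.data` pipelines.

```python
def to_text_to_text(example, source_lang="English", target_lang="German"):
    """Cast a raw translation pair into the text-to-text form used by T5."""
    return {
        "inputs": f"translate {source_lang} to {target_lang}: {example['en']}",
        "targets": example["de"],
    }

raw = {"en": "That is good.", "de": "Das ist gut."}
print(to_text_to_text(raw))
# {'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}
```

The task prefix (`translate English to German:`) is what tells the single shared model which task to perform.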
+
+ For a scalable data pipeline and an evaluation framework, we use
+ [`SeqIO`](https://github.com/google/seqio), which was factored out of the [T5
+ library][t5_github]. A `seqio.Task` packages together the raw dataset, the
+ vocabulary, preprocessing such as tokenization, and evaluation metrics such as
+ [BLEU](https://aclanthology.org/P02-1040.pdf), and provides a
+ [`tf.data`](https://www.tensorflow.org/guide/data) instance.
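To make the BLEU metric mentioned above concrete, here is a toy sketch of its core idea: clipped n-gram precisions combined by a geometric mean, times a brevity penalty. The function names and whitespace tokenization are our simplifications; real evaluations should use SeqIO's metric implementations or SacreBLEU.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """Count the n-grams (as tuples) occurring in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def bleu(hypothesis, reference, max_n=4):
    """Toy sentence-level BLEU with whitespace tokenization."""
    hyp, ref = hypothesis.split(), reference.split()
    log_precisions = []
    for n in range(1, max_n + 1):
        hyp_ngrams, ref_ngrams = ngrams(hyp, n), ngrams(ref, n)
        total = sum(hyp_ngrams.values())
        # Clip each n-gram count by its count in the reference.
        clipped = sum(min(c, ref_ngrams[g]) for g, c in hyp_ngrams.items())
        if total == 0 or clipped == 0:
            return 0.0
        log_precisions.append(math.log(clipped / total))
    # Brevity penalty: punish hypotheses shorter than the reference.
    bp = 1.0 if len(hyp) >= len(ref) else math.exp(1 - len(ref) / len(hyp))
    return bp * math.exp(sum(log_precisions) / max_n)

print(bleu("Das ist gut .", "Das ist gut ."))  # 1.0
```

Production BLEU additionally handles corpus-level aggregation, multiple references, and smoothing for zero n-gram counts, which this sketch omits.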
+
+ [The T5 library][t5_github] provides a number of `seqio.Task`s that were used in the
+ [T5 paper][t5_paper]. In this example, we use [wmt_t2t_ende_v003](https://github.com/google-research/text-to-text-transfer-transformer/blob/d81c0bab2a41b4d5dfbe4971de32f7d67df65f31/t5/data/tasks.py#L212).
+
+ Before training or fine-tuning, you need to download the
+ ["wmt_t2t_translate"](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate)
+ dataset first.
+
+ ```sh
+ # Data dir to save the processed dataset in "gs://data_dir" format.
+ TFDS_DATA_DIR="..."
+
+ # Make sure that the dataset package is up-to-date.
+ python3 -m pip install --upgrade tfds-nightly
+
+ # Pre-download dataset.
+ tfds build wmt_t2t_translate --data_dir=${TFDS_DATA_DIR}
+ ```
+
+ ### Training
+
+ To run a training job, we use the `t5x/train.py` script.
+
+ ```sh
+ # Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
+ MODEL_DIR="..."
+ T5X_DIR="..."  # directory where the T5X repo is cloned.
+ TFDS_DATA_DIR="..."
+
+ python3 ${T5X_DIR}/t5x/train.py \
+   --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
+   --gin.MODEL_DIR=\"${MODEL_DIR}\" \
+   --tfds_data_dir=${TFDS_DATA_DIR}
+ ```
+
+ The configuration for this training run is defined in the Gin file
+ [base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin).
+ [Gin-config](https://github.com/google/gin-config) is a library to handle
+ configurations based on dependency injection. Among many benefits, Gin allows
+ users to pass custom components, such as a custom model, to the T5X library
+ without having to modify the core library. The [custom
+ components](#custom-components) section shows how this is done.
+
+ While the core library is independent of Gin, it is central to the examples we
+ provide. Therefore, we provide a short [introduction][gin-primer] to Gin in the
+ context of T5X. All the configurations are written to a file "config.gin" in
+ `MODEL_DIR`. This makes debugging as well as reproducing the experiment much
+ easier.
+
+ In addition to the `config.gin`, a `model-info.txt` file summarizes the model
+ parameters (shape, names of the axes, partitioning info) as well as the
+ optimizer states.
+
+ #### TensorBoard
+
+ To monitor the training in [TensorBoard](https://www.tensorflow.org/tensorboard), it is much easier (due to
+ authentication issues) to launch TensorBoard on your own machine and _not_ in
+ the TPU VM. So, in a commandline on your own machine (not the one where you
+ ssh'ed into the TPU VM), launch TensorBoard with the `logdir` pointing to the
+ `MODEL_DIR`.
+
+ ```sh
+ # NB: run this on your machine, not the TPU VM!
+ MODEL_DIR="..."  # Copy from the TPU VM.
+ tensorboard --logdir=${MODEL_DIR}
+ ```
+
+ Or you can launch TensorBoard inside a Colab. In a Colab cell, run
+
+ ```python
+ from google.colab import auth
+ auth.authenticate_user()
+ ```
+
+ to authorize the Colab to access the GCS bucket, and then launch TensorBoard.
+
+ ```python
+ %load_ext tensorboard
+ model_dir = "..."  # Copy from the TPU VM.
+ %tensorboard --logdir=model_dir
+ ```
+
+ ### Fine-tuning
+
+ We can leverage the benefits of self-supervised pre-training by initializing
+ from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.
+
+ ```sh
+ # Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
+ MODEL_DIR="..."
+
+ # Data dir to save the processed dataset in "gs://data_dir" format.
+ TFDS_DATA_DIR="..."
+ T5X_DIR="..."  # directory where the T5X repo is cloned.
+
+ python3 ${T5X_DIR}/t5x/train.py \
+   --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin" \
+   --gin.MODEL_DIR=\"${MODEL_DIR}\" \
+   --tfds_data_dir=${TFDS_DATA_DIR}
+ ```
+
+ **Note:** when supplying a string, dict, list, tuple value, or a bash variable
+ via a flag, you must put it in quotes. In the case of strings, this requires
+ escaped quotes (`\"<string>\"`). For example:
+ `--gin.utils.DatasetConfig.split=\"validation\"` or
+ `--gin.MODEL_DIR=\"${MODEL_DIR}\"`.
+
+ Gin makes it easy to change a number of configurations. For example, you can
+ change `partitioning.PjitPartitioner.num_partitions` (overriding
+ the value in
+ [base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin))
+ to change the parallelism strategy and pass it as a commandline arg.
+
+ ```sh
+ --gin.partitioning.PjitPartitioner.num_partitions=8
+ ```
+
+ ### Evaluation
+
+ To run the offline (i.e., without training) evaluation, you can use the
+ `t5x/eval.py` script.
+
+ ```sh
+ EVAL_OUTPUT_DIR="..."  # directory to write eval output
+ T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
+ TFDS_DATA_DIR="..."
+ CHECKPOINT_PATH="..."
+
+ python3 ${T5X_DIR}/t5x/eval.py \
+   --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin" \
+   --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
+   --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
+   --tfds_data_dir=${TFDS_DATA_DIR}
+ ```
+
+ ### Inference
+
+ To run inference, you can use the `t5x/infer.py` script. Here we use the same
+ `seqio.Task`, but for inference we do not use the targets features other than
+ logging them alongside the predictions in a JSON file.
+
+ ```sh
+ INFER_OUTPUT_DIR="..."  # directory to write infer output
+ T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
+ TFDS_DATA_DIR="..."
+ CHECKPOINT_PATH="..."
+
+ python3 ${T5X_DIR}/t5x/infer.py \
+   --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin" \
+   --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
+   --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
+   --tfds_data_dir=${TFDS_DATA_DIR}
+ ```
314
+
315
+ ### Exporting as TensorFlow Saved Model
316
+
317
+ A pretrained model can be exported as a TensorFlow SavedModel and deployed
318
+ to the Vertex AI Prediction service using the
319
+ [Optimized TensorFlow Runtime](https://cloud.google.com/vertex-ai/docs/predictions/optimized-tensorflow-runtime).
320
+ Please note that the exported model won't work on the OSS-based
321
+ [TensorFlow Model Server](https://github.com/tensorflow/serving).
322
+
323
+ ```sh
324
+ T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
325
+ CHECKPOINT_PATH="..."
326
+
327
+ BATCH_SIZE=None
328
+ BEAM_SIZE=1
329
+
330
+ # Use 'bfloat16' if you plan to run exported model on NVIDIA A100 or newer GPUs,
331
+ # for other GPUs use 'float32'.
332
+ ACTIVATION_DTYPE=bfloat16
333
+
334
+ # Version numbers must be numeric. We generate one based on datetime.
335
+ VERSION=$(date +%Y%m%d%H%M%S)
336
+
337
+ NAME=t5x_base_${ACTIVATION_DTYPE} # Model name.
338
+
339
+ # Path to export the model to. Note that the export script will add a _cpu
340
+ # suffix after the model name.
341
+ OUTPUT=${CHECKPOINT_PATH}/saved_model.${NAME}/${VERSION}
342
+
343
+ declare -a ARGS=(
344
+ --gin_file=t5x/examples/t5/t5_1_1/base.gin
345
+ --gin_file=t5x/t5x/configs/runs/export.gin
346
+ --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}"
347
+ --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\"
348
+ --gin.MODEL_NAME=\"/ml/${USER}/t5x_base\"
349
+ --gin.MODEL_OUTPUT_DIR=\"${OUTPUT}\"
350
+ --gin.BEAM_SIZE=${BEAM_SIZE}
351
+ --gin.BATCH_SIZE=${BATCH_SIZE}
352
+ --gin.export_lib.save.partitioner=None
353
+ --gin.export_lib.save.warmup_examples="['hello world']"
354
+ --gin.export_lib.ExportableModule.use_batch_function=False
355
+ --gin.export_lib.ExportableModule.use_gpu=False
356
+ --gin.export_lib.ExportableModule.jit_compile=False
357
+ --gin.ACTIVATION_DTYPE=\"${ACTIVATION_DTYPE}\"
358
+ --gin.network.T5Config.dtype=\"${ACTIVATION_DTYPE}\"
359
+ --gin.utils.RestoreCheckpointConfig.dtype=\"${ACTIVATION_DTYPE}\"
360
+ --gin.DROPOUT_RATE=0.0
361
+ )
362
+
363
+ (python3 ${T5X_DIR}/t5x/export.py "${ARGS[@]}")
364
+ ```
365
+
366
+ For detailed argument definitions, refer to
367
+ [export.gin](t5x/configs/runs/export.gin).
368
+
369
+ You can run XL and smaller models on NVIDIA A100 40GB, and XXL models on
370
+ NVIDIA A100 80GB.
371
+
372
+ ## Custom components
373
+
374
+ [The translation example](#example-english-to-german-translation) uses the
375
+ encoder-decoder model that T5X provides as well as the dataset from the T5
376
+ library. This section shows how you can use your own dataset and model and
377
+ pass them in via Gin.
378
+
379
+ ### Example: custom dataset in a user directory
380
+
381
+ For this example, we have the following directory structure with
382
+ `${HOME}/dir1/user_dir` representing a user directory with custom components.
383
+
384
+ ```
385
+ ${HOME}
386
+ └── dir1
387
+    └── user_dir
388
+    ├── t5_1_1_base_de_en.gin
389
+    └── tasks.py
390
+ ```
391
+
392
+ As an example, let's define a new dataset. Here we use the same Translation
393
+ dataset but we define the translation task in the opposite direction, i.e.,
394
+ German to English instead of English to German. We define this task in `tasks.py`:
395
+
396
+ ```py
397
+ # ${HOME}/dir1/user_dir/tasks.py
398
+
399
+ import functools
400
+ import seqio
401
+ import tensorflow_datasets as tfds
402
+ from t5.evaluation import metrics
403
+ from t5.data import preprocessors
404
+
405
+ vocabulary = seqio.SentencePieceVocabulary(
406
+ 'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
407
+ output_features = {
408
+ 'inputs': seqio.Feature(vocabulary=vocabulary),
409
+ 'targets': seqio.Feature(vocabulary=vocabulary)
410
+ }
411
+
412
+ seqio.TaskRegistry.add(
413
+ 'wmt_t2t_de_en_v003',
414
+ source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
415
+ preprocessors=[
416
+ functools.partial(
417
+ preprocessors.translate,
418
+ source_language='de', target_language='en'),
419
+ seqio.preprocessors.tokenize,
420
+ seqio.CacheDatasetPlaceholder(),
421
+ seqio.preprocessors.append_eos_after_trim,
422
+ ],
423
+ metric_fns=[metrics.bleu],
424
+ output_features=output_features)
425
+ ```
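+ Before launching a full run, you can optionally sanity-check the registration
+ by loading a few examples directly with `seqio`. This is only a sketch: it
+ assumes the `wmt_t2t_translate` TFDS data is already available locally, and the
+ sequence lengths are arbitrary illustrative values.
+
+ ```py
+ import seqio
+ import tasks  # Registers 'wmt_t2t_de_en_v003'.
+
+ task = seqio.get_mixture_or_task('wmt_t2t_de_en_v003')
+ ds = task.get_dataset(sequence_lengths={'inputs': 64, 'targets': 64},
+                       split='train')
+ for ex in ds.take(1):
+   print({k: v.shape for k, v in ex.items()})
+ ```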
426
+
427
+ In the Gin file, most of the settings are equivalent to those used in the
428
+ [En->De example](#example-english-to-german-translation). So we include the Gin
429
+ file from that example. To use the "wmt_t2t_de_en_v003" task we just defined,
431
+ we need to import the task module "tasks.py". Note that we use a relative path
431
+ defined with respect to the user directory. This will be specified as a
432
+ flag.
433
+
434
+ ```py
435
+ # ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
436
+ from __gin__ import dynamic_registration
437
+ import tasks # This imports the task defined in dir1/user_dir/tasks.py.
438
+
439
+ include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin"
440
+ MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"
441
+ ```
442
+
443
+ Finally, we launch training, passing the user directory via the
444
+ `--gin_search_paths` flag so that the Gin file and Python modules can be
445
+ specified with relative paths.
446
+
447
+ ```sh
448
+ PROJECT_DIR=${HOME}"/dir1/user_dir"
449
+ T5X_DIR="..." # directory where the t5x is cloned.
450
+ TFDS_DATA_DIR="..."
451
+ MODEL_DIR="..."
452
+ export PYTHONPATH=${PROJECT_DIR}
453
+
454
+ python3 ${T5X_DIR}/t5x/train.py \
455
+ --gin_search_paths=${PROJECT_DIR} \
456
+ --gin_file="t5_1_1_base_de_en.gin" \
457
+ --gin.MODEL_DIR=\"${MODEL_DIR}\" \
458
+ --tfds_data_dir=${TFDS_DATA_DIR}
459
+ ```
460
+
461
+ ## Checkpoints
462
+
463
+ ### Native Checkpoints
464
+
465
+ We have released the checkpoints of many of the original T5 models and their
466
+ variants in a native T5X format for maximal efficiency.
467
+ See the [complete list](https://github.com/google-research/t5x/blob/main/docs/models.md) including the
468
+ matching Gin configuration files.
469
+
470
+ These are converted from the public [Mesh TensorFlow
471
+ checkpoints](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511).
473
+
474
+
475
+ ### Compatibility with the Mesh TensorFlow checkpoints
476
+ The Mesh TensorFlow checkpoints trained using the [T5 library][t5_github] can be
477
+ directly loaded into T5X. For example, we can rerun the fine-tuning example
478
+ initializing from the MTF checkpoint by changing the `INIT_CHECKPOINT` Gin
479
+ macro.
480
+
481
+ ```sh
482
+ # Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
483
+ MODEL_DIR="..."
484
+
485
+ # Data dir to save the processed dataset in "gs://data_dir" format.
486
+ TFDS_DATA_DIR="..."
487
+ T5X_DIR="..." # directory where the T5X repo is cloned.
488
+
489
+ python3 ${T5X_DIR}/t5x/train.py \
490
+ --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin" \
491
+ --gin.MODEL_DIR=\"${MODEL_DIR}\" \
492
+ --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
493
+ --gin.INIT_CHECKPOINT=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \
494
+ --tfds_data_dir=${TFDS_DATA_DIR}
495
+ ```
496
+
497
+ Note that restoring directly from the Mesh TensorFlow checkpoints can be
498
+ inefficient if heavy model parallelism is used for large models. This is
499
+ because each host loads the entire copy of the model first and then keeps only
500
+ the relevant slices dictated by the model parallelism specification. If you have
501
+ Mesh TensorFlow checkpoints that you run often, we recommend converting the
502
+ checkpoints to T5X native format using the
503
+ [convert_tf_checkpoint script](t5x/scripts/convert_tf_checkpoint.py).
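+ As a rough sketch, a conversion run follows the same pattern as the other
+ scripts in this README. The flag names below are assumptions for illustration,
+ not the script's documented interface; check the script's `--help` output for
+ the actual flags.
+
+ ```sh
+ T5X_DIR="..."  # directory where the t5x is cloned.
+
+ # Hypothetical flag names shown for illustration only; verify with --help.
+ python3 ${T5X_DIR}/t5x/scripts/convert_tf_checkpoint.py \
+   --tf_checkpoint_path="gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000" \
+   --output_dir="gs://your-bucket/t5_1_1_base_t5x"
+ ```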
504
+
505
+
506
+ ## Citing T5X
507
+ Please use the following BibTeX entry to cite T5X.
508
+
509
+ ```
510
+ @article{roberts2022t5x,
511
+ url = {https://arxiv.org/abs/2203.17189},
512
+ author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra, Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester, Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy, Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel, Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov, Alexander and Newlan, Joshua and Gesmundo, Andrea},
513
+ title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
514
+ journal={arXiv preprint arXiv:2203.17189},
515
+ year = {2022},
516
+ }
517
+ ```
518
+
519
+
520
+ ## Note
521
+ This is not an officially supported Google product.
522
+
523
+ [t5_paper]: https://arxiv.org/abs/1910.10683
524
+ [t5_github]: https://github.com/google-research/text-to-text-transfer-transformer
525
+ [gin-primer]: docs/usage/gin.md
t5x-main/docs/_static/t5x_theme.css ADDED
@@ -0,0 +1,23 @@
1
+ @import url("theme.css");
2
+
3
+ .wy-nav-content {
4
+ max-width: 1290px;
5
+ }
6
+
7
+ .rst-content table.docutils {
8
+ width: 100%;
9
+ }
10
+
11
+ .rst-content table.docutils td {
12
+ vertical-align: top;
13
+ padding: 0;
14
+ }
15
+
16
+ .rst-content table.docutils td p {
17
+ padding: 8px;
18
+ }
19
+
20
+ .rst-content div[class^=highlight] {
21
+ border: 0;
22
+ margin: 0;
23
+ }
t5x-main/docs/_templates/autosummary/t5x_module.rst ADDED
@@ -0,0 +1,23 @@
1
+ {{ fullname | escape | underline}}
2
+
3
+ .. currentmodule:: {{ module }}
4
+
5
+ .. autoclass:: {{ objname }}
6
+ :exclude-members:
7
+
8
+ {% block methods %}
9
+
10
+ .. automethod:: __call__
11
+
12
+ {% if methods %}
13
+ .. rubric:: Methods
14
+
15
+ .. autosummary::
16
+
17
+ {% for item in methods %}
18
+ {%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %}
19
+ ~{{ name }}.{{ item }}
20
+ {%- endif %}
21
+ {%- endfor %}
22
+ {% endif %}
23
+ {% endblock %}
t5x-main/docs/api_reference/index.rst ADDED
@@ -0,0 +1,100 @@
1
+ API Reference
2
+ =============
3
+
4
+ Binaries
5
+ --------
6
+
7
+ .. toctree::
8
+ :maxdepth: 3
9
+
10
+ t5x.train
11
+ t5x.infer
12
+ t5x.eval
13
+ t5x.main
14
+
15
+ Training
16
+ ---------
17
+
18
+ .. toctree::
19
+ :maxdepth: 3
20
+
21
+ t5x.trainer
22
+ t5x.optimizers
23
+ t5x.interactive_model
24
+ t5x.train_state
25
+ t5x.state_utils
26
+ t5x.losses
27
+ t5x.metrics
28
+ t5x.utils
29
+ t5x.adafactor
30
+
31
+ Inference
32
+ ---------
33
+
34
+ .. toctree::
35
+ :maxdepth: 3
36
+
37
+ t5x.decoding
38
+
39
+ Models
40
+ ------
41
+
42
+ .. toctree::
43
+ :maxdepth: 3
44
+
45
+ t5x.models
46
+
47
+ Checkpointing
48
+ -------------
49
+
50
+ .. toctree::
51
+ :maxdepth: 3
52
+
53
+ t5x.checkpoints
54
+ t5x.checkpoint_utils
55
+ t5x.checkpoint_importer
56
+
57
+
58
+ Partitioning
59
+ ------------
60
+
61
+ .. toctree::
62
+ :maxdepth: 3
63
+
64
+ t5x.partitioning
65
+
66
+ Config
67
+ ------
68
+
69
+ .. toctree::
70
+ :maxdepth: 3
71
+
72
+ t5x.config_utils
73
+ t5x.gin_utils
74
+
75
+ Utils
76
+ -----
77
+
78
+ .. toctree::
79
+ :maxdepth: 3
80
+
81
+ t5x.test_utils
82
+ t5x.binary_search
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
t5x-main/docs/api_reference/t5x.adafactor.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.adafactor package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.adafactor
5
+
6
+ .. automodule:: t5x.adafactor
7
+ :members:
t5x-main/docs/api_reference/t5x.binary_search.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.binary_search package
2
+ =========================
3
+
4
+ .. currentmodule:: t5x.binary_search
5
+
6
+ .. automodule:: t5x.binary_search
7
+ :members:
t5x-main/docs/api_reference/t5x.checkpoint_importer.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.checkpoint_importer package
2
+ ===============================
3
+
4
+ .. currentmodule:: t5x.checkpoint_importer
5
+
6
+ .. automodule:: t5x.checkpoint_importer
7
+ :members:
t5x-main/docs/api_reference/t5x.checkpoint_utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.checkpoint_utils package
2
+ ============================
3
+
4
+ .. currentmodule:: t5x.checkpoint_utils
5
+
6
+ .. automodule:: t5x.checkpoint_utils
7
+ :members:
t5x-main/docs/api_reference/t5x.checkpoints.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.checkpoints package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.checkpoints
5
+
6
+ .. automodule:: t5x.checkpoints
7
+ :members:
t5x-main/docs/api_reference/t5x.config_utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.config_utils package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.config_utils
5
+
6
+ .. automodule:: t5x.config_utils
7
+ :members:
t5x-main/docs/api_reference/t5x.decoding.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.decoding package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.decoding
5
+
6
+ .. automodule:: t5x.decoding
7
+ :members:
t5x-main/docs/api_reference/t5x.eval.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.eval binary
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.eval
5
+
6
+ .. automodule:: t5x.eval
7
+ :members:
t5x-main/docs/api_reference/t5x.gin_utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.gin_utils package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.gin_utils
5
+
6
+ .. automodule:: t5x.gin_utils
7
+ :members:
t5x-main/docs/api_reference/t5x.infer.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.infer binary
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.infer
5
+
6
+ .. automodule:: t5x.infer
7
+ :members:
t5x-main/docs/api_reference/t5x.interactive_model.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.interactive_model package
2
+ =============================
3
+
4
+ .. currentmodule:: t5x.interactive_model
5
+
6
+ .. automodule:: t5x.interactive_model
7
+ :members:
t5x-main/docs/api_reference/t5x.losses.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.losses package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.losses
5
+
6
+ .. automodule:: t5x.losses
7
+ :members:
t5x-main/docs/api_reference/t5x.main.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.main binary
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.main
5
+
6
+ .. automodule:: t5x.main
7
+ :members:
t5x-main/docs/api_reference/t5x.metrics.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.metrics package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.metrics
5
+
6
+ .. automodule:: t5x.metrics
7
+ :members:
t5x-main/docs/api_reference/t5x.models.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.models package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.models
5
+
6
+ .. automodule:: t5x.models
7
+ :members:
t5x-main/docs/api_reference/t5x.optimizers.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.optimizers package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.optimizers
5
+
6
+ .. automodule:: t5x.optimizers
7
+ :members:
t5x-main/docs/api_reference/t5x.partitioning.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.partitioning package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.partitioning
5
+
6
+ .. automodule:: t5x.partitioning
7
+ :members:
t5x-main/docs/api_reference/t5x.state_utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.state_utils package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.state_utils
5
+
6
+ .. automodule:: t5x.state_utils
7
+ :members:
t5x-main/docs/api_reference/t5x.test_utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.test_utils package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.test_utils
5
+
6
+ .. automodule:: t5x.test_utils
7
+ :members:
t5x-main/docs/api_reference/t5x.train.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.train binary
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.train
5
+
6
+ .. automodule:: t5x.train
7
+ :members:
t5x-main/docs/api_reference/t5x.train_state.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.train_state package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.train_state
5
+
6
+ .. automodule:: t5x.train_state
7
+ :members:
t5x-main/docs/api_reference/t5x.trainer.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.trainer package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.trainer
5
+
6
+ .. automodule:: t5x.trainer
7
+ :members:
t5x-main/docs/api_reference/t5x.utils.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ t5x.utils package
2
+ ========================
3
+
4
+ .. currentmodule:: t5x.utils
5
+
6
+ .. automodule:: t5x.utils
7
+ :members:
t5x-main/docs/conf.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Configuration file for the Sphinx documentation builder.
16
+
17
+ This file only contains a selection of the most common options. For a full
18
+ list see the documentation:
19
+ https://www.sphinx-doc.org/en/master/usage/configuration.html
20
+ """
21
+
22
+ # pylint:disable=all
23
+ # -- Path setup --------------------------------------------------------------
24
+
25
+ # If extensions (or modules to document with autodoc) are in another directory,
26
+ # add these directories to sys.path here. If the directory is relative to the
27
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
28
+ #
29
+ import os
30
+ import sys
31
+
32
+ sys.path.insert(0, os.path.abspath('..'))
33
+
34
+ # patch sphinx
35
+ import docs.conf_sphinx_patch
36
+
37
+ # -- Project information -----------------------------------------------------
38
+
39
+ project = 'T5X'
40
+ copyright = '2023, The T5X authors' # pylint: disable=redefined-builtin
41
+ author = 'The T5X authors'
42
+
43
+ # -- General configuration ---------------------------------------------------
44
+
45
+ # Add any Sphinx extension module names here, as strings. They can be
46
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
47
+ # ones.
48
+ extensions = [
49
+ 'sphinx.ext.autodoc',
50
+ 'sphinx.ext.autosummary',
51
+ 'sphinx.ext.autosectionlabel',
52
+ 'sphinx.ext.doctest',
53
+ 'sphinx.ext.intersphinx',
54
+ 'sphinx.ext.mathjax',
55
+ 'sphinx.ext.napoleon',
56
+ 'sphinx.ext.viewcode',
57
+ 'myst_nb',
58
+ 'sphinx_design',
59
+ ]
60
+
61
+ # The suffix(es) of source filenames.
62
+ # You can specify multiple suffix as a list of string:
63
+ #
64
+ source_suffix = ['.rst', '.ipynb', '.md']
65
+
66
+ autosummary_generate = True
67
+
68
+ master_doc = 'index'
69
+
70
+ autodoc_typehints = 'none'
71
+
72
+ # Add any paths that contain templates here, relative to this directory.
73
+ templates_path = ['_templates']
74
+
75
+ # List of patterns, relative to source directory, that match files and
76
+ # directories to ignore when looking for source files.
77
+ # This pattern also affects html_static_path and html_extra_path.
78
+ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
79
+
80
+ # -- Options for HTML output -------------------------------------------------
81
+
82
+ # The theme to use for HTML and HTML Help pages. See the documentation for
83
+ # a list of builtin themes.
84
+ #
85
+ # html_theme = 'pydata_sphinx_theme'
86
+ html_theme = 'sphinx_book_theme'
87
+ html_css_files = ['css/t5x_theme.css']
88
+
89
+ # The name of an image file (relative to this directory) to place at the top
90
+ # of the sidebar.
91
+ html_logo = './t5x.png'
92
+ html_favicon = './t5x.png'
93
+
94
+ # title of the website
95
+ html_title = ''
96
+
97
+ # Add any paths that contain custom static files (such as style sheets) here,
98
+ # relative to this directory. They are copied after the builtin static files,
99
+ # so a file named 'default.css' will overwrite the builtin 'default.css'.
100
+ html_static_path = ['_static']
101
+
102
+ html_theme_options = {
103
+ 'repository_url': 'https://github.com/google-research/t5x',
104
+ 'use_repository_button': True, # add a 'link to repository' button
105
+ 'use_issues_button': False, # add an 'Open an Issue' button
106
+ 'path_to_docs': (
107
+ 'docs'
108
+ ), # used to compute the path to launch notebooks in colab
109
+ 'launch_buttons': {
110
+ 'colab_url': 'https://colab.research.google.com/',
111
+ },
112
+ 'prev_next_buttons_location': None,
113
+ 'show_navbar_depth': 1,
114
+ }
115
+
116
+ # -- Options for myst ----------------------------------------------
117
+ # uncomment line below to avoid running notebooks during development
118
+ # nb_execution_mode = 'off'
119
+ # Notebook cell execution timeout; defaults to 30.
120
+ nb_execution_timeout = 100
121
+ # List of patterns, relative to source directory, that match notebook
122
+ # files that will not be executed.
123
+ myst_enable_extensions = ['dollarmath']
124
+ # raise exceptions on execution so CI can catch errors
125
+ nb_execution_allow_errors = False
126
+ nb_execution_raise_on_error = True
127
+
128
+ # -- Extension configuration -------------------------------------------------
129
+
130
+ # Tell sphinx-autodoc-typehints to generate stub parameter annotations including
131
+ # types, even if the parameters aren't explicitly documented.
132
+ always_document_param_types = True
t5x-main/docs/conf_sphinx_patch.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright 2024 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Patch Sphinx to improve documentation aesthetics."""
16
+
17
+ # TODO(cgarciae): Send a PR to sphinx to upstream this fix.
18
+ # Issue: https://github.com/google/flax/issues/2196
19
+ # This patch is needed to make autosummary provide the "annotations"
20
+ # variable so we can exclude function attributes from the methods list
21
+ # in flax_module.rst. The patch as such only adds this single line:
22
+ #
23
+ # ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())'
24
+ #
25
+ # We should consider sending a PR to sphinx so we can get rid of this.
26
+ # Original source:
27
+ # https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351
28
+ from typing import Any, Dict, List, Set, Tuple
29
+ import sphinx.ext.autodoc
30
+ import sphinx.ext.autosummary.generate as ag
+ from sphinx.locale import __  # Needed for the translated warning message below.
31
+
32
+
33
+ # pylint:disable=all
34
+ def generate_autosummary_content(
35
+ name: str,
36
+ obj: Any,
37
+ parent: Any,
38
+ template: ag.AutosummaryRenderer,
39
+ template_name: str,
40
+ imported_members: bool,
41
+ app: Any,
42
+ recursive: bool,
43
+ context: Dict,
44
+ modname: str = None,
45
+ qualname: str = None,
46
+ ) -> str:
47
+ doc = ag.get_documenter(app, obj, parent)
48
+
49
+ def skip_member(obj: Any, name: str, objtype: str) -> bool:
50
+ try:
51
+ return app.emit_firstresult(
52
+ 'autodoc-skip-member', objtype, name, obj, False, {}
53
+ )
54
+ except Exception as exc:
55
+ ag.logger.warning(
56
+ __(
57
+ 'autosummary: failed to determine %r to be documented, '
58
+ 'the following exception was raised:\n%s'
59
+ ),
60
+ name,
61
+ exc,
62
+ type='autosummary',
63
+ )
64
+ return False
65
+
66
+ def get_class_members(obj: Any) -> Dict[str, Any]:
67
+ members = sphinx.ext.autodoc.get_class_members(
68
+ obj, [qualname], ag.safe_getattr
69
+ )
70
+ return {name: member.object for name, member in members.items()}
71
+
72
+ def get_module_members(obj: Any) -> Dict[str, Any]:
73
+ members = {}
74
+ for name in ag.members_of(obj, app.config):
75
+ try:
76
+ members[name] = ag.safe_getattr(obj, name)
77
+ except AttributeError:
78
+ continue
79
+ return members
80
+
81
+ def get_all_members(obj: Any) -> Dict[str, Any]:
82
+ if doc.objtype == 'module':
83
+ return get_module_members(obj)
84
+ elif doc.objtype == 'class':
85
+ return get_class_members(obj)
86
+ return {}
87
+
88
+ def get_members(
89
+ obj: Any,
90
+ types: Set[str],
91
+ include_public: List[str] = [],
92
+ imported: bool = True,
93
+ ) -> Tuple[List[str], List[str]]:
94
+ items: List[str] = []
95
+ public: List[str] = []
96
+
97
+ all_members = get_all_members(obj)
98
+ for name, value in all_members.items():
99
+ documenter = ag.get_documenter(app, value, obj)
100
+ if documenter.objtype in types:
101
+ # skip imported members if expected
102
+ if imported or getattr(value, '__module__', None) == obj.__name__:
103
+ skipped = skip_member(value, name, documenter.objtype)
104
+ if skipped is True:
105
+ pass
106
+ elif skipped is False:
107
+ # show the member forcedly
108
+ items.append(name)
109
+ public.append(name)
110
+ else:
111
+ items.append(name)
112
+ if name in include_public or not name.startswith('_'):
113
+ # considers member as public
114
+ public.append(name)
115
+ return public, items
116
+
117
+ def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]:
118
+ """Find module attributes with docstrings."""
119
+ attrs, public = [], []
120
+ try:
121
+ analyzer = ag.ModuleAnalyzer.for_module(name)
122
+ attr_docs = analyzer.find_attr_docs()
123
+ for namespace, attr_name in attr_docs:
124
+ if namespace == '' and attr_name in members:
125
+ attrs.append(attr_name)
126
+ if not attr_name.startswith('_'):
127
+ public.append(attr_name)
128
+ except ag.PycodeError:
129
+ pass # give up if ModuleAnalyzer fails to parse code
130
+ return public, attrs
131
+
132
+ def get_modules(obj: Any) -> Tuple[List[str], List[str]]:
133
+ items: List[str] = []
134
+ for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__):
135
+ fullname = name + '.' + modname
136
+ try:
137
+ module = ag.import_module(fullname)
138
+ if module and hasattr(module, '__sphinx_mock__'):
139
+ continue
140
+ except ImportError:
141
+ pass
142
+
143
+ items.append(fullname)
144
+ public = [x for x in items if not x.split('.')[-1].startswith('_')]
145
+ return public, items
146
+
147
+ ns: Dict[str, Any] = {}
148
+ ns.update(context)
149
+
150
+ if doc.objtype == 'module':
151
+ scanner = ag.ModuleScanner(app, obj)
152
+ ns['members'] = scanner.scan(imported_members)
153
+ ns['functions'], ns['all_functions'] = get_members(
154
+ obj, {'function'}, imported=imported_members
155
+ )
156
+ ns['classes'], ns['all_classes'] = get_members(
157
+ obj, {'class'}, imported=imported_members
158
+ )
159
+ ns['exceptions'], ns['all_exceptions'] = get_members(
160
+ obj, {'exception'}, imported=imported_members
161
+ )
162
+ ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members'])
163
+ ispackage = hasattr(obj, '__path__')
164
+ if ispackage and recursive:
165
+ ns['modules'], ns['all_modules'] = get_modules(obj)
166
+ elif doc.objtype == 'class':
167
+ ns['members'] = dir(obj)
168
+ ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys())
169
+ ns['methods'], ns['all_methods'] = get_members(
170
+ obj, {'method'}, ['__init__']
171
+ )
172
+ ns['attributes'], ns['all_attributes'] = get_members(
173
+ obj, {'attribute', 'property'}
174
+ )
175
+ ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())
176
+
177
+ if modname is None or qualname is None:
178
+ modname, qualname = ag.split_full_qualified_name(name)
179
+
180
+ if doc.objtype in ('method', 'attribute', 'property'):
181
+ ns['class'] = qualname.rsplit('.', 1)[0]
182
+
183
+ if doc.objtype in ('class',):
184
+ shortname = qualname
185
+ else:
186
+ shortname = qualname.rsplit('.', 1)[-1]
187
+
188
+ ns['fullname'] = name
189
+ ns['module'] = modname
190
+ ns['objname'] = qualname
191
+ ns['name'] = shortname
192
+
193
+ ns['objtype'] = doc.objtype
194
+ ns['underline'] = len(name) * '='
195
+
196
+ if template_name:
197
+ return template.render(template_name, ns)
198
+ else:
199
+ return template.render(doc.objtype, ns)
200
+
201
+
202
+ ag.generate_autosummary_content = generate_autosummary_content
t5x-main/docs/contributions.md ADDED
@@ -0,0 +1,64 @@
1
+ # Contributions
2
+
3
+ T5X was developed as part of the T5 Infrastructure effort at Google Research.
4
+
5
+ Adam Roberts founded and leads the project, designed and wrote much of `seqio`
6
+ and `t5x`, and co-authored the
7
+ [T5X and SeqIO paper](https://arxiv.org/abs/2203.17189). Hyung Won Chung
8
+ designed and wrote much of `t5x`, led its open sourcing, and co-authored the
9
+ paper. Anselm Levskaya built the initial prototype for `t5x` and wrote much of
10
+ the code. Gaurav Mishra leads `seqio`, implemented deterministic pipelines, and
11
+ co-authored the paper. James Bradbury implemented partitioning in `t5x` and
12
+ co-wrote the paper.
13
+
14
+ Daniel Andor, Sharan Narang, Brian Lester, Colin Gaffney, Afroz Mohiuddin,
15
+ Curtis Hawthorne, Aitor Lewkowycz, Alex Salcianu, Marc van Zee, Jacob Austin,
16
+ Sebastian Goodman, Livio Baldini Soares, Haitang Hu, Sasha Tsvyashchenko,
17
+ Aakanksha Chowdhery, Jasmijn Bastings, Jannis Bulian, Xavier Garcia, Jianmo Ni,
18
+ Andrew Chen, Kathleen Kenealy, Kehang Han, Jonathan H. Clark, Stephan Lee, Dan
19
+ Garrette, and James Lee-Thorp made substantial code contributions.
20
+
21
+ Colin Raffel and Noam Shazeer helped design `seqio`. Marvin Ritter advised on
22
+ deterministic pipelines and the use of CLU Metrics. Maarten Bosma helped design
23
+ deterministic pipelines. Jeremy Maitin-Shepard advised on the use of
24
+ TensorStore. Alexandre Passos and Ryan Sepassi advised on overall technical
25
+ design.
26
+
27
+ Noah Fiedel is a member of the leadership team, contributed to the high level
28
+ design and roadmap, and co-wrote the paper. Mark Omernick, Brennan Saeta, Ryan
29
+ Sepassi, Alexander Spiridonov (Product Manager), and Josh Newlan (Technical
30
+ Program Manager) are members of the leadership team and co-wrote the paper.
31
+ Andrea Gesmundo is a member of the leadership team and contributed to the
32
+ internal infrastructure component.
33
+
34
+ Thanks to the many other contributors to the project: Ian Simon, Reiner Pope,
35
+ Vincent Zhao, Pierre Ruyssen, Linting Xue, Junwhan Ahn, Barret Zoph, David
36
+ Dohan, Masumi Parekh, Chang Lan, Frederick Liu, Julien Amelot, Luheng He, Fede
37
+ Lebron, Rebecca Chen, Anosh Raj, Mandy Guo, Ethan Dyer, Mihai Tiuca, Hongkun Yu,
38
+ Kevin Brooks, David Soergel, Kelvin Guu, Joshua Ainslie, Luyao Xu, Ji Ma, Josh
39
+ Gardner, Daphne Ippolito, Peter Hawkins, Bo Pang, Marc Rasi, Wei Li, Wenhu Chen,
40
+ Iulia Turc, John Wieting, Alex Passos, Zonglin Li, Katie Everett, Olivier
41
+ Bachem, Francesco Piccinno, Jakub Adamek, Jonathan Heek, Parker Schuh, Hexiang
42
+ Hu, Du Phan, Max Moroz, David Miller, Ryan Doherty, David Elworthy, Alfonso
43
+ Castaño, Julian Eisenschlos, Vlad-Doru Ion, Lucas Dixon, Ron Shapiro, Dinghua
44
+ Li, Aaron Parisi, Xi Chen, Nan Ding, Chung-ching Chang, Timothy Dozat, Natalia
45
+ Ponomareva, Delesley Hutchins, Ankush Garg, Yu-Han Liu, Mehrdad Khatir, Costanza
46
+ Conforti, Philipp Keck, Raphaël Marinier, Marie Pellat, Raghuram Vadapalli,
47
+ Joshua Maynez, Yi Tay, Xihui Wu, David Belanger, Luke Metz, Dan Zheng, Deepti
48
+ Bhatia, Hariharan Shanmugavadivel, Rewon Child, Rigel Swavely, Mihir Sanjay
49
+ Kale, Arash Afkanpour, Roberto Rama, Juro Gottweis, Jonathan Herzig, Yilei Yang,
50
+ Elias Mizan, Pedram Pejman, Jiayu Ye, Smit Sanghavi, Rahul Joshi, Ziqiang Feng,
51
+ Charles Sutton, Weikang Zhou, Liam Fedus, Shanqing Cai, Ginger Perng, Yash
52
+ Katariya, Urvashi Khandelwal, Sebastian Gehrmann, Edward Loper, Tianze Shi, Luke
53
+ Vilnis, Amelia Archer, Tom Weingarten, David Zats, Murtaza Dhuliawala, Xin Xie,
54
+ Sahil Dua, André Susano Pinto, Piotr Padlewski, Sascha Rothe, Erik Aas, Felix
55
+ Stahlberg, Ken Durden, Christina Sorokin, Jaehoon Lee, Roy Frostig, Jacob
56
+ Devlin, Jorge Gonzalez Mendez, Deepak Ramachandran, Santiago Ontanon, Karthik
57
+ Raman, Yi Sun, Ali Elqursh, Reuben La Haye, Adam Fahrenkopf, Alex Polozov, Vinay
58
+ Ramasesh, Ian Tenney.
59
+
60
+ Thanks to NVIDIA for GPU contributions: Sahil Jain, Terry Kong, Yu-Hang Tang,
61
+ Ming Huang, Frederic Bastien, Sharath Turuvekere Sreenivas, Xiaowei Ren, Ryan Jeng,
62
+ Reese Wang.
63
+
64
+ Thanks to Douglas Eck and Zoubin Ghahramani for sponsoring the project.
t5x-main/docs/index.md ADDED
@@ -0,0 +1,65 @@
1
+ # T5X
2
+
3
+
4
+ Note: T5X is community-supported since ~2023. For critical use cases, consider
5
+ using libraries like TuneLab (go/tunelab) and Gemax Prod (go/gemax-prod). See
6
+ https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-to-gemax-prod for useful tips on transitioning.
7
+
8
+ ## Overview
9
+
10
+ T5X is a modular, composable, research-friendly framework for high-performance,
11
+ configurable, self-service training, evaluation, and inference of sequence
12
+ models (starting with language) at many scales.
13
+
14
+ It is essentially a new and improved implementation of the
15
+ [T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md) (based on Mesh TensorFlow) in JAX and Flax. To learn
16
+ more, see the [T5X Paper](https://arxiv.org/abs/2203.17189).
17
+
18
+ ## Getting Started
19
+
20
+ Here are some quick tutorials to help you get started with common use-cases on
21
+ T5X:
22
+
23
+ #### [Introductory Colabs](tutorials.md)
24
+
25
+ If you are new to T5X, we recommend starting with our introductory Colab series,
26
+ which introduces core concepts of both T5X and SeqIO. More colabs will be added
27
+ to this series regularly!
28
+
29
+ #### [Fine-tuning a model](usage/finetune.md)
30
+
31
+ This tutorial outlines the steps to fine-tune an existing pre-trained model with
32
+ T5X on common downstream Tasks/Mixtures available on SeqIO. This is one of the
33
+ simplest and most common use cases of T5X. If you're new to T5X, this tutorial
34
+ is the recommended starting point.
35
+
36
+ #### [Running evaluation on a model](usage/eval.md)
37
+
38
+ This tutorial outlines the steps to evaluate a model with T5X on downstream
39
+ Tasks/Mixtures defined in SeqIO.
40
+
41
+ #### [Running inference on a model](usage/infer.md)
42
+
43
+ This tutorial outlines the steps to run inference on a model with T5X.
44
+
45
+ #### [Training a model from scratch](usage/pretrain.md)
46
+
47
+ This tutorial outlines the steps to pretrain a model with T5X on Tasks/Mixtures
48
+ defined in SeqIO.
49
+
50
+ #### [Gin Primer](usage/gin.md)
51
+
52
+ This tutorial provides a quick introduction to Gin, a lightweight configuration
53
+ framework for Python that is used to configure training, eval and inference jobs
54
+ on T5X.
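
To give a flavor before the tutorial, here is a minimal sketch of a Gin file. The include path and the `TRAIN_STEPS`/`DROPOUT_RATE` bindings mirror names used in the T5X example configs, but treat the specific values as illustrative:

```gin
# Illustrative Gin file: start from a base model config, then override bindings.
include 't5x/examples/t5/t5_1_1/small.gin'

# Top-level macros that run configs reference elsewhere.
TRAIN_STEPS = 1020000
DROPOUT_RATE = 0.1
```

Bindings set later in a file (or via `--gin.NAME=value` on the command line) override earlier ones, which is how a small override file can customize a large base config.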
55
+
56
+ #### [Partitioning Primer](usage/partitioning.md)
57
+
58
+ This tutorial provides background on what model and data partitioning are and
59
+ how they can be configured in T5X.
60
+
61
+ #### [Metrics Overview](usage/metrics.md)
62
+
63
+ This tutorial provides an overview of how metrics can be used and customized to
64
+ evaluate T5X models.
65
+
t5x-main/docs/index.rst ADDED
@@ -0,0 +1,24 @@
1
+ ******************************
2
+ T5X
3
+ ******************************
4
+
5
+
6
+ T5X is a modular, composable, research-friendly framework for high-performance,
7
+ configurable, self-service training, evaluation, and inference of sequence
8
+ models (starting with language) at many scales.
9
+
10
+ It is essentially a new and improved implementation of the
11
+ `T5 codebase <https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md>`__
12
+ (based on Mesh TensorFlow) in JAX and Flax. To learn more, see the
13
+ `T5X Paper <https://arxiv.org/abs/2203.17189>`__.
14
+
15
+ .. toctree::
16
+ :maxdepth: 2
17
+ :caption: Table of Contents
18
+
19
+ Quick Start <overview>
20
+ Tutorials <tutorials>
21
+ Usage Guides <usage/index>
22
+ Models <models>
23
+ api_reference/index
24
+ contributions
t5x-main/docs/models.md ADDED
@@ -0,0 +1,318 @@
1
+ # Models
2
+
3
+
4
+ This page lists the available pre-trained T5 models. To use a pre-trained model,
5
+ you need a Gin config file that defines the model params, and the model
6
+ checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin
7
+ configs for common T5 pre-trained models have been made available for use in
8
+ T5X. The following is a list of these pre-trained models and their Gin and
9
+ checkpoint locations.
10
+
11
+ + All checkpoints:
12
+ [`gs://t5-data/pretrained_models/t5x/`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/)
13
+ + All Gin files:
14
+ [`t5x/configs/models/`](https://github.com/google-research/t5x/blob/main/t5x/configs/)
15
+
16
+ ### Selecting a Model
17
+
18
+ Publicly Available Models:
19
+
20
+ Model | Use Case
21
+ ---------------------------------------------------- | --------
22
+ [T5 1.1](#t5-11-checkpoints) | Improved T5, recommended for most research. English only.
23
+ [T5](#t5-checkpoints) | The original T5 work for reproducibility. English only.
24
+ [T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [prompt tuning paper](https://arxiv.org/abs/2104.08691).
25
+ [mT5](#mt5-checkpoints) | Multilingual T5. Recommended for multilingual research. Note that at smaller scales (at least through XL), mT5 performance is lower than T5 on English tasks.
26
+ [mT5 LM-Adapted](#mt5-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647).
27
+ [umT5](#umt5-checkpoints) | umT5, an updated mT5 model trained using a more uniform language distribution, per [the UniMax paper](https://openreview.net/forum?id=kXwdL1cWOAi).
28
+ [ByT5](#byt5-checkpoints) | ByT5. A "token-free" model that uses UTF-8 bytes for input and output. Recommended for tasks involving word-internal phenomena such as spelling, pronunciation, or morphology.
29
+ [LongT5](#longt5-checkpoints) | Recommended checkpoints to fine-tune for long input sequence tasks
30
+ [MoE](#mixture-of-experts-moe-checkpoints) | Useful for MoE experimentation.
31
+ [Flan-T5](#flan-t5-checkpoints) | General purpose T5 checkpoints for few-shot and finetuning. We recommend Flan-T5 over vanilla T5 and T5 LM-adapted
32
+ [UL2](#ul2-checkpoints) | Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131)
33
+ [BigScience](#bigscience-checkpoints) | Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832)
34
+ [FLIP](#flip-checkpoints) | Language-Image models trained with an alternative to CLIP, presented in the [FLIP paper](https://arxiv.org/abs/2212.00794)
35
+ [RankGen](#rankgen-checkpoints) | 1.2B parameter encoder model for English to score model generations given a prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726)
36
+ [Dipper](#dipper-checkpoints) | 11B parameter paraphrase generation model from the [Dipper paper](https://arxiv.org/abs/2303.13408)
37
+
38
+
39
+ ### Public Research Models
40
+
41
+ #### T5 Checkpoints
42
+
43
+ These are the checkpoints used in the paper [Exploring the Limits of Transfer
44
+ Learning with a Unified Text-to-Text
45
+ Transformer](https://arxiv.org/abs/1910.10683). They are encoder-decoder models
46
+ pre-trained on [C4](https://www.tensorflow.org/datasets/catalog/c4) with a "span
47
+ corruption" denoising objective, in addition to a mixture of downstream tasks
48
+ including: GLUE, SuperGLUE, CNN/Daily Mail, SQuAD, and WMT.
49
+
50
+ **Vocabulary:**
51
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
52
+
53
+ Model | Gin File Location | Checkpoint Location
54
+ -------- | ------------------------------------------------------------------------------ | -------------------
55
+ T5 Small | [t5_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_small)
56
+ T5 Base | [t5_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_base/checkpoint_999900](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_base)
57
+ T5 Large | [t5_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_large/checkpoint_1000700](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_large)
58
+ T5 3B | [t5_3B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/3B.gin) | [gs://t5-data/pretrained_models/t5x/t5_3B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_3B)
59
+ T5 11B | [t5_11B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/11B.gin) | [gs://t5-data/pretrained_models/t5x/t5_11B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_11B)
60
+
61
+ #### T5 1.1 Checkpoints
62
+
63
+ These are similar to the models from [Exploring the Limits of Transfer Learning
64
+ with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), but
65
+ with the following improvements:
66
+
67
+ * GEGLU activation in the feed-forward hidden layer, rather than ReLU (see
68
+ https://arxiv.org/abs/2002.05202).
69
+ * Dropout was turned off in pre-training (quality win). Dropout should be
70
+ re-enabled during fine-tuning.
71
+ * Pre-trained on C4 only without mixing in the downstream tasks.
72
+ * No parameter sharing between the embedding and classifier layers.
73
+ * "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit
74
+ different: larger d_model and smaller num_heads and d_ff.
75
+
76
+ For English-language, sequence-to-sequence-style tasks (ones where the goal is
77
+ to map from an input text sequence to a target sequence) these are usually the
78
+ best models to fine-tune.
79
+
80
+ **Vocabulary:**
81
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
82
+
83
+ Model | Gin File Location | Checkpoint Location
84
+ ------------ | ---------------------------------------------------------------------------------- | -------------------
85
+ T5 1.1 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small)
86
+ T5 1.1 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_base)
87
+ T5 1.1 Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_large)
88
+ T5 1.1 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xl)
89
+ T5 1.1 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xxl)
90
+
91
+ #### T5 1.1 LM-Adapted Checkpoints
92
+
93
+ These "LM-adapted" models are initialized from T5 1.1 (above) and trained for an
94
+ additional 100K steps on the LM objective discussed in the
95
+ [T5 paper](https://arxiv.org/abs/1910.10683). This adaptation improves the
96
+ ability of the model to be used for
97
+ [prompt tuning](https://arxiv.org/abs/2104.08691). These checkpoints were also
98
+ used within the BigScience [T0](https://arxiv.org/abs/2110.08207) project.
99
+
100
+ **Vocabulary:**
101
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
102
+
103
+ Model | Gin File Location | Checkpoint Location
104
+ -------------------- | ------------------------------------------------------------------------------------------------------------------- | -------------------
105
+ T5 1.1 LM-100K Small | [t5_1_1_small.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin) | [t5_1_1_lm100k_small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_small)
106
+ T5 1.1 LM-100K Base | [t5_1_1_base.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_base.gin) | [t5_1_1_lm100k_base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_base)
107
+ T5 1.1 LM-100K Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_large.gin) | [t5_1_1_lm100k_large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_large)
108
+ T5 1.1 LM-100K XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xl.gin) | [t5_1_1_lm100k_xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl)
109
+ T5 1.1 LM-100K XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xxl.gin) | [t5_1_1_lm100k_xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl)
110
+
111
+
112
+ #### mT5 Checkpoints
113
+
114
+ These are the checkpoints used in the paper
115
+ [mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer](https://aclanthology.org/2021.naacl-main.41/).
116
+ They are encoder-decoder models trained on
117
+ [multilingual C4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual)
118
+ with a denoising objective. These are the best checkpoints to fine-tune for
119
+ non-English sequence-to-sequence tasks.
120
+
121
+ **Vocabulary:**
122
+ [mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra)
123
+
124
+ Model | Gin File Location | Checkpoint Location
125
+ --------- | ---------------------------------------------------------------------------- | -------------------
126
+ mT5 Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [gs://t5-data/pretrained_models/t5x/mt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_small)
127
+ mT5 Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_base)
128
+ mT5 Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_large)
129
+ mT5 XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xl)
130
+ mT5 XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xxl)
131
+
132
+ #### mT5 LM-Adapted Checkpoints
133
+
134
+ These are the checkpoints released as part of the
135
+ [zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647).
136
+
137
+ These "LM-adapted" models are initialized from mT5 (above) and trained for an
138
+ additional 100K steps on the LM objective discussed in the
139
+ [T5 paper](https://arxiv.org/abs/1910.10683).
140
+
141
+ This adaptation improves the ability of the model to be used for
142
+ [prompt tuning](https://arxiv.org/abs/2104.08691).
143
+
144
+ **Vocabulary:**
145
+ [mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra)
146
+
147
+ Model | Gin File Location | Checkpoint Location
148
+ -------------------- | ---------------------------------------------------------------------------- | -------------------
149
+ mT5 LM-Adapted Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [mt5_lm_adapted/small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/small/checkpoint_1100000)
150
+ mT5 LM-Adapted Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [mt5_lm_adapted/base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/base/checkpoint_1100000)
151
+ mT5 LM-Adapted Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [mt5_lm_adapted/large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/large/checkpoint_1100000)
152
+ mT5 LM-Adapted XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [mt5_lm_adapted/xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xl/checkpoint_1100000)
153
+ mT5 LM-Adapted XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [mt5_lm_adapted/xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xxl/checkpoint_1100000)
154
+
155
+ #### umT5 Checkpoints
156
+
157
+ These are the checkpoints described in the paper [UniMax: Fairer and More
158
+ Effective Language Sampling for Large-Scale Multilingual
159
+ Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi). umT5 is similar to
160
+ mT5 (see above); both are multilingual encoder-decoder models ranging from 300M
161
+ to 13B parameters, trained on the mC4 corpus using a denoising objective. umT5
162
+ is trained on a fresher version of the mC4 corpus (3.1.0), and with a more
163
+ uniform language balancing strategy.
164
+
165
+ **Vocabulary:** [umt5.256000](https://console.cloud.google.com/storage/browser/t5-data/vocabs/umt5.256000)
166
+
167
+ Model | Gin File Location | Checkpoint Location
168
+ ---------- | --------------------------------------------------------------------------------------------------------- | -------------------
169
+ umT5 Small | [umt5/pretrain_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_small.gin) | [umt5/small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/small/checkpoint_1000000)
170
+ umT5 Base | [umt5/pretrain_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_base.gin) | [umt5/base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/base/checkpoint_1000000)
171
+ umT5 XL | [umt5/pretrain_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin) | [umt5/xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xl/checkpoint_1000000)
172
+ umT5 XXL | [umt5/pretrain_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin) | [umt5/xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xxl/checkpoint_1000000)
173
+
174
+ #### ByT5 Checkpoints
175
+
176
+ These are the checkpoints used in the paper
177
+ [ByT5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models](https://aclanthology.org/2022.tacl-1.17/).
178
+ They are similar to mT5 (above), but are "token-free", processing text as raw
179
+ UTF-8 bytes, as opposed to using a pretrained subword vocabulary. These models
180
+ are more robust to character-level noise, and outperform parameter-matched mT5
181
+ models in many settings, particularly on word-level tasks sensitive to spelling,
182
+ pronunciation, or morphology. However inference is significantly slower, up to
183
+ 10x depending on the task.
184
+
185
+ **Vocabulary:** None
186
+
187
+ Model | Gin File Location | Checkpoint Location
188
+ ---------- | ------------------------------------------------------------------------------ | -------------------
189
+ ByT5 Small | [byt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/small.gin) | [gs://t5-data/pretrained_models/t5x/byt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_small)
190
+ ByT5 Base | [byt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/base.gin) | [gs://t5-data/pretrained_models/t5x/byt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_base)
191
+ ByT5 Large | [byt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/large.gin) | [gs://t5-data/pretrained_models/t5x/byt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_large)
192
+ ByT5 XL | [byt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xl)
193
+ ByT5 XXL | [byt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xxl)
194
+
195
+ #### LongT5 Checkpoints
196
+
197
+ These are the checkpoints used in the paper
198
+ [LongT5: Efficient Text-to-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916).
199
+ They are encoder-decoder models trained on
200
+ [C4](https://www.tensorflow.org/datasets/catalog/c4) using the PEGASUS Principle
201
+ Sentences Generation objective. These are the recommended checkpoints to
202
+ fine-tune for long input sequence tasks.
203
+
204
+ ##### LongT5 Local Attention Checkpoints
205
+
206
+ The checkpoints below use local attention, which uses a sliding window to reduce
207
+ training time from quadratic (with regard to input length) to linear. These are
208
+ the recommended checkpoints to use for faster training/inference time.
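
The quadratic-versus-linear claim can be sanity-checked by counting query-key pairs. This is a back-of-the-envelope sketch; the window size below is arbitrary, not the value LongT5 actually uses.

```python
def full_attention_pairs(n: int) -> int:
    """Query-key pairs scored by full self-attention: every token vs. every token."""
    return n * n

def local_attention_pairs(n: int, window: int) -> int:
    """Pairs scored when each token attends only to a fixed window of neighbors."""
    return n * min(window, n)

# Quadrupling the sequence length quadruples the local-attention cost
# (linear in n) but multiplies the full-attention cost by sixteen
# (quadratic in n).
print(full_attention_pairs(4096) // full_attention_pairs(1024))              # 16
print(local_attention_pairs(4096, 128) // local_attention_pairs(1024, 128))  # 4
```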
209
+
210
+ **Vocabulary:**
211
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
212
+
213
+ Model | Gin File Location | Checkpoint Location
214
+ ---------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -------------------
215
+ LongT5 Local Attention Base | [longt5/models/longt5_1_1_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_base)
216
+ LongT5 Local Attention Large | [longt5/models/longt5_1_1_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_large)
217
+
218
+ ##### LongT5 Transient Global Attention Checkpoints
219
+
220
+ The checkpoints below use transient global attention, which introduces global
221
+ tokens at each encoder layer to allow tokens to interact with each other at
222
+ longer distances. These are the recommended checkpoints to use for increased
223
+ performance on long input sequence tasks.
224
+
225
+ **Vocabulary:**
226
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
227
+
228
+ Model | Gin File Location | Checkpoint Location
229
+ ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------
230
+ LongT5 Base | [longt5/models/longt5_1_1_transient_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_base)
231
+ LongT5 Large | [longt5/models/longt5_1_1_transient_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_large)
232
+ LongT5 XL | [longt5/models/longt5_1_1_transient_xl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_xl.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_xl)
233
+
234
+ #### Mixture of Experts (MoE) Checkpoints
235
+
236
+ These MoE checkpoints need to be used with T5X MoE overrides -- specifically,
237
+ the MoeTrainer and the MoePjitPartitioner. For example, for fine-tuning, use the
238
+ [MoE fine-tune run config](https://github.com/google-research/t5x/blob/main/t5x/contrib/moe/configs/runs/finetune.gin).
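
Putting those overrides together might look like the Gin sketch below. The include paths follow the repository layout of the linked configs and the checkpoint path comes from the table below, but treat the exact paths as illustrative:

```gin
# Illustrative sketch: layer the MoE run config over a Switch model config.
include 'flaxformer/t5x/configs/moe/models/switch_base.gin'
include 't5x/contrib/moe/configs/runs/finetune.gin'

INITIAL_CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100'
```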
239
+
240
+
241
+ ##### Converted Mesh Tensorflow checkpoints
242
+
243
+ [Switch Transformer model](https://arxiv.org/abs/2101.03961).
244
+
245
+ **Vocabulary:**
246
+ [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
247
+
248
+
249
+ Model | Gin File Location | Checkpoint Location
250
+ ---------------------------------------- | ------------------------------------------------------------------------------------------------------------ | -------------------
251
+ Switch Transformer Base 8 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e8)
252
+ Switch Transformer Base 16 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e16/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e16)
253
+ Switch Transformer Base 32 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e32/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e32)
254
+ Switch Transformer Base 64 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e64/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e64)
255
+ Switch Transformer Base 128 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e128/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e128)
256
+ Switch Transformer Base 256 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e256/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e256)
257
+ Switch Transformer Large 128 Experts | [switch_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_large.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/large/e128/checkpoint_483100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/large/e128)
258
+ Switch Transformer XXL 128 Experts | [switch_xxl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_xxl.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128/checkpoint_634600](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128)
259
+ Switch Transformer C 2048 Experts (1.6T) | [switch_c.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_c.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048/checkpoint_611800](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048)
260
+
261
+
262
+ #### Flan-T5 Checkpoints
263
+
264
+ These are the checkpoints released as part of the paper
265
+ [Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416).
266
+ They were initialized from the
267
+ [T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) and instruction-finetuned.
268
+
269
+ They significantly outperform the LM-adapted checkpoints. For example,
270
+ Flan-T5-XXL outperforms T5-LM-XXL by 26.6% absolute on the normalized average
271
+ score. It even outperforms a much larger PaLM 62B model on
272
+ [BigBench Hard](https://arxiv.org/abs/2210.09261), a set of challenging
273
+ BIG-Bench tasks.
274
+
275
+ Unlike the vanilla T5 checkpoints, these can be directly used for few-shot
276
+ prompting as well as standard finetuning. See
277
+ [Chung et al. 2022](https://arxiv.org/abs/2210.11416) for details.
278
+
279
+ Model | Gin File Location | Checkpoint Location
280
+ ------------- | ---------------------------------------------------------------------------------- | -------------------
281
+ Flan-T5 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000)
282
+ Flan-T5 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000)
283
+ Flan-T5 Large | [t5_1_1/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000)
284
+ Flan-T5 XL    | [t5_1_1/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000)
285
+ Flan-T5 XXL   | [t5_1_1/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000)
286
+
287
+ #### UL2 Checkpoints
288
+
289
+ Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the
290
+ UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131). Checkpoints
291
+ are released at
292
+ https://github.com/google-research/google-research/tree/master/ul2#checkpoints.
293
+
294
+ #### BigScience Checkpoints
295
+
296
+ Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832),
297
+ released at
298
+ https://github.com/bigscience-workshop/architecture-objective/tree/main#checkpoints.
299
+
300
+ #### FLIP Checkpoints
301
+
302
+ Language-Image models trained with an alternative to CLIP, presented in the
303
+ [FLIP paper](https://arxiv.org/abs/2212.00794). Checkpoints are released at
304
+ https://github.com/facebookresearch/flip#results-and-pre-trained-flip-models.
305
+
306
+ #### RankGen Checkpoints
307
+
308
+ 1.2B parameter encoder model for English to score model generations given a
309
+ prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726).
310
+ Checkpoints are released at
311
+ https://github.com/google-research/google-research/tree/master/rankgen.
312
+
313
+ #### Dipper Checkpoints
314
+
315
+ 11B parameter paraphrase generation model from the
316
+ [Dipper paper](https://arxiv.org/abs/2303.13408). Checkpoints are released at
317
+ https://github.com/google-research/google-research/tree/master/dipper.
318
+
t5x-main/docs/overview.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ```{include} ../README.md
2
+ ```
t5x-main/docs/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ sphinx>=4.4.0
2
+ myst_parser>=0.16.1
3
+ myst_nb
4
+ sphinx-design
5
+ sphinx-book-theme
6
+
7
+ # Must install t5x itself for notebook execution and autodocs to work.
8
+ .
t5x-main/docs/t5x.png ADDED

Git LFS Details

  • SHA256: 5e903d6a7cb99b192a23b895cd30157d5661cd0e895b3f1d6f2027fdfb1b66dd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.84 MB
t5x-main/docs/tutorials.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5X Introductory Tutorial Series
2
+
3
+
4
+ ## Overview
5
+
6
+ This series of guides is a self-contained introduction to T5X, a modular,
7
+ composable, research-friendly framework for high-performance, configurable,
8
+ self-service training, evaluation, and inference of sequence models (starting
9
+ with language) at many scales.
10
+
11
+
12
+ ## How to Use These Guides
13
+
14
+ Most entries in this series are colab notebooks (click the blue banners to the
15
+ right of each heading below), allowing you to run our tutorial code
16
+ interactively. We encourage you to do that! Play around, change things, see what
17
+ happens!
18
+
19
+
20
+ ## T5X Guides
21
+
22
+ ### Codelab 1: An Introduction to T5X
23
+
24
+ <a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
25
+
26
+ In this colab, you will learn about some of the basic T5X components and put
27
+ them to use to run training, inference, and evaluation on natural text inputs.
28
+
29
+ ### Codelab 2: Training Deep Dive
30
+
31
+ <a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
32
+
33
+ In this colab, you will dive into how to restore T5X models from checkpoints and
34
+ run training, while also getting an introduction to the T5X trainer.
35
+
36
+ ### Codelab 3: Inference Deep Dive
37
+
38
+ <a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
39
+
40
+ In this colab, you will dive into how the Interactive Model does decoding to
41
+ generate predictions and scores for a given input.
42
+
43
+ ### Codelab 4: Evaluation Deep Dive
44
+
45
+ <a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
46
+
47
+ In this colab, you will dive into how the InteractiveModel takes a batch of
48
+ inputs and targets and runs evaluation to produce various metrics.
49
+
50
+
51
+ ### More Colabs coming soon!
t5x-main/docs/usage/auxiliary.md ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Auxiliary Job
2
+
3
+
4
+ ## Introduction
5
+
6
+ This page outlines the steps needed to use the auxiliary job capabilities
7
+ available in T5X.
8
+
9
+ ## Overview
10
+
11
+ There are a variety of situations in which running a single job is insufficient
12
+ or suboptimal. For example, consider the following scenarios:
13
+
14
+ + You want to keep track of evaluation (`infer_eval` or `train_eval`) metrics
15
+ per checkpoint, but evaluation takes a very long time due to having a large
16
+ eval dataset, slow decoding, or multiple tasks to evaluate.
17
+
18
+ + You want to finetune every checkpoint on a downstream task as you train.
19
+
20
+ + You have customized evaluation code that you want to run on every checkpoint
21
+ as you train, but that does not naturally fit within a `seqio.Evaluator`
22
+ framework.
23
+
24
+ In cases like these, users can make use of the auxiliary job functionality. At a
25
+ high-level, the auxiliary job will launch a new job every time a new checkpoint
26
+ is saved. This new job can either re-use the `train.py` binary (e.g. for
27
+ continuous finetuning) or a different one. For example, this allows users to
28
+ perform continuous evaluation (using `eval.py`) without slowing down the
29
+ training job. We will provide detailed examples showing how to use the auxiliary
30
+ job for these use-cases.
31
+
32
+ When this new job is launched, the controller will replace four gin macros:
33
+ `MODEL_DIR`, `MIXTURE_OR_TASK_NAME`, `INITIAL_CHECKPOINT_PATH`, `TRAIN_STEPS`.
34
+ The second of these is set by the user-controlled flag (more on this below), and
35
+ the third one is equal to the last checkpoint seen. Aside from this, users are
36
+ free to modify the configuration as needed. Beyond gin macros, the auxiliary job
37
+ can also have different resource requirements, priority, and even cell placement
38
+ from the train job.
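As a rough illustration (the file name and `%gin.REQUIRED` placeholders below are assumptions for this sketch, not taken from the T5X repo), an auxiliary job gin file typically declares these four macros so the controller can substitute them:

```gin
# Hypothetical auxiliary job gin file: the controller overrides these macros.
MODEL_DIR = %gin.REQUIRED                # output dir for the auxiliary job
MIXTURE_OR_TASK_NAME = %gin.REQUIRED     # set from the auxiliary mixtures flag
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED  # set to the last checkpoint seen
TRAIN_STEPS = %gin.REQUIRED              # extended by final_auxiliary_job_steps
```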
39
+
40
+ ## Example 1: Separate evaluation job.
41
+
42
+ ### Step 1: Choose a model architecture.
43
+
44
+ Similar to pretraining, we will need some gin configuration. For this example,
45
+ we will use the T5-1.1-Base model.
46
+
47
+ ### Step 2: Choose a SeqIO Task/Mixture for training and evaluation.
48
+
49
+ In this example, we will use the classic task of English-French translation from
50
+ WMT14, which is conveniently available as a SeqIO task in the tasks file from
51
+ the T5 tasks under the name `'wmt14_enfr_v003'`.
52
+
53
+ ### Step 3: Write a Gin config.
54
+
55
+ Unlike pretraining or finetuning, we will need two gin files for this setup: one
56
+ for the training job, and one for the auxiliary job. The train gin file will
57
+ have the same requirements as the gin file for pretraining or finetuning. The
58
+ auxiliary job gin file can leverage these gin files or be its own independent
59
+ gin file, depending on the user’s choice. For this example, we will make a new
60
+ gin file that is mostly a wrapper around `pretrain.gin` with some additional
61
+ hardcoded features. We will use this gin file for the train job and `eval.gin`
62
+ for the auxiliary job.
63
+
64
+ ### Step 4: Launch your experiment.
65
+
66
+ Our sample script will be quite similar to the one used in pretraining and
67
+ finetuning, but with a few additional flags which we describe below.
68
+
69
+ + `auxiliary_job_mixtures`: This is a comma-separated list of mixtures. A
70
+ separate auxiliary job will be run for each mixture and will replace the gin
71
+ macro `MIXTURE_OR_TASK_NAME`. Note that you must pass this flag even if you
72
+ are using a custom binary that does not need a mixture; otherwise no
73
+ auxiliary job will run.
74
+
75
+ + `auxiliary_job_gin_file`: This is identical to `gin_file`, except it is used
76
+ for the auxiliary job instead of the train job.
77
+
78
+ + `replace_gin_file`: If True, this auxiliary launcher will not use any of the
79
+ gin files from train job. This is necessary when using a binary different
80
+ from `train.py`, since the top-level functions will not match.
81
+
82
+ + `auxiliary_job_cell`: The cell in which to run your job. Note that this can
83
+ be different from the training cell.
84
+
85
+ + `auxiliary_job_platform`: The platform to use for the auxiliary. Note that
86
+ this can be different from the one used for the train job, allowing users to
87
+ use smaller configurations for evaluation than needed for training.
88
+
89
+ `auxiliary_job_build_target_path`: The binary to use for the auxiliary job.
90
+
91
+ + `final_auxiliary_job_steps`: This flag controls how many additional steps to
92
+ take when using the auxiliary job for finetuning. Setting to 0 enables
93
+ continuous evaluation.
94
+
95
+ We provide the sample script below.
96
+
97
+ ```sh
98
+ declare -a ARGS=(
99
+ --cell=iz
100
+ --platform=jd=2x2
101
+ --final_auxiliary_job_steps=0
102
+ --replace_gin_file=True
103
+ --auxiliary_job_mixtures=wmt14_enfr_v003
104
+ --auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin
105
+ --auxiliary_job_cell=iz
106
+ --auxiliary_job_platform=jd=2x2
107
+ --auxiliary_job_build_target_path=//t5x:eval
108
+ --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin
109
+ )
110
+
111
+ gxm t5x/google/xm_launch.py "${ARGS[@]}"
112
+ ```
113
+
114
+ ## Example 2: Continuous finetuning job.
115
+
116
+ In this example, we will be pretraining a model on a span corruption task on the
117
+ C4 dataset, and finetuning it on the WMT'14 English-French translation task. As
118
+ before, we will launch a new auxiliary job each time a checkpoint is saved.
119
+ However, instead of using the `eval.py` binary, we will use the `train.py`
120
+ binary.
121
+
122
+ ### Step 1: Choose a model architecture.
123
+
124
+ We will use the T5-1.1-Base model as in the previous example.
125
+
126
+ ### Step 2: Choose a SeqIO Task/Mixture for training and evaluation.
127
+
128
+ For pretraining, we re-use the span corruption task `c4_v220_span_corruption`
129
+ available in the T5 mixtures `tasks.py` file.
130
+
131
+ ### Step 3: Write a Gin config.
132
+
133
+ As before, we need our gin files to contain all the desired macros in them. We
134
+ thus create two new gin files: `base_c4_pretrain.gin` for the train job and
135
+ `base_wmtenfr14_finetune.gin` for the auxiliary job.
136
+
137
+ ### Step 4: Launch your experiment.
138
+
139
+ Our script is quite similar to the first example, with the same flags as before
140
+ but with the appropriate changes. The main distinction is that we must change the
141
+ flag `final_auxiliary_job_steps` to be non-zero to start finetuning. We will
142
+ settle for a modest 200 steps for the sake of demonstration (and evaluate every
143
+ 100 steps), but users should use more steps in realistic scenarios. We also
144
+ use the `train.py` binary instead of `eval.py`.
145
+
146
+ We provide the sample script below.
147
+
148
+ ```sh
149
+ declare -a ARGS=(
150
+ --cell=iz
151
+ --platform=jd=2x2
152
+ --final_auxiliary_job_steps=200
153
+ --replace_gin_file=True
154
+ --auxiliary_job_mixtures=wmt14_enfr_v003
155
+ --auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin
156
+ --auxiliary_job_cell=iz
157
+ --auxiliary_job_platform=jd=2x2
158
+ --auxiliary_job_build_target_path=//t5x:train
159
+ --gin_file=t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin
160
+ )
161
+
162
+ gxm t5x/google/xm_launch.py "${ARGS[@]}"
163
+ ```
164
+
165
+ ## Common Gotchas.
166
+
167
+ We outline a few common error patterns that we have encountered.
168
+
169
+ + **Not passing a value for the `auxiliary_job_mixtures` flag.** Even if you
170
+ have the desired task in your gin file, or you use a differently named macro,
171
+ you should still pass a value for this flag, since the launch script launches
172
+ one new job per value of this flag.
173
+
174
+ + **Not setting `replace_gin_file=True` when using a different binary from
175
+ train.py.** This will usually yield an error that there is no `train`
176
+ function.
177
+
178
+ + **No metrics being logged.** It can be tempting to use gin files usually
179
+ used for evaluation. However, one must ensure that the corresponding SeqIO
180
+ evaluators still log to TensorBoard, otherwise you won’t see the
181
+ metrics.
182
+
183
+ + **Slow `train_eval`.** While the approach outlined above separates out the
184
+ infer_eval job, it may be that even train_eval is too slow. In these
185
+ situations, we suggest adding the metrics from train_eval into the
186
+ `metrics_fn` argument of the SeqIO task and have them be computed in the
187
+ auxiliary job as well. To do this with teacher forcing, you will have to use
188
+ `train.py` instead of `eval.py`.
189
+
190
+ + **Using `CHECKPOINT_PATH` rather than `INITIAL_CHECKPOINT_PATH`.** For legacy
191
+ reasons, the auxiliary job uses the macro `INITIAL_CHECKPOINT_PATH` rather
192
+ than `CHECKPOINT_PATH` as found in `eval.gin`. Make sure to use
193
+ `INITIAL_CHECKPOINT_PATH` when building your gin scripts.
194
+
195
+ + **Gin macros being ignored when passed through the format
196
+ `gin.{MACRO}={VAL}`.** In the current setup, you must include all gin macros
197
+ in the gin script. Attempting to pass them as additional flags will usually
198
+ not work.
199
+
200
+ + **Not setting `final_auxiliary_job_steps=0` when performing continuous
201
+ evaluation.** The current parameter controller uses this as a check. When
202
+ this is true, it will replace the `EVAL_OUTPUT_DIR` folder with the current
203
+ `MODEL_DIR`, so that the evaluation metrics are saved in the right place and
204
+ the metrics are shown correctly in TensorBoard.
t5x-main/docs/usage/decoding.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Decoding
2
+
3
+
4
+ This page outlines the decoding functions that T5X provides out-of-the-box and
5
+ how custom decoding functions can be used for a Transformer model, i.e., an
6
+ instance of
7
+ [`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb).
8
+ Here we refer to decoding as a process of generating a sequence of items from a
9
+ fixed alphabet (e.g., generating token ids from the vocabulary).
10
+
11
+ There are two major ways to configure the decoding routine. The first method is
12
+ to define a decode function that follows the `DecodeFnCallable` signature. This
13
+ is more restrictive as it enforces the call signature but users don't need to
14
+ modify the model code.
15
+
16
+ The second method is to subclass a model class and override
17
+ `predict_batch_with_aux` method. While this provides more flexibility, it
18
+ requires rewriting the method.
19
+
20
+ ## Option 1: defining a decoding function
21
+
22
+ If a desired decoding process can follow `DecodeFnCallable`, it can be
23
+ registered as a private attribute of a
24
+ [`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb)
25
+ by passing it as a `decode_fn` argument to its constructor.
26
+
27
+ ### Decoding function call signature
28
+
29
+ `DecodeFnCallable` has the following call signature.
30
+
31
+
32
+ It takes in `inputs`, which is an int32 array with a shape `[batch_size,
33
+ max_decode_len]`. These are the input tokens to the decoder. For standard
34
+ encoder-decoder models like T5, this is initialized as zeros with the desired
35
+ decoding length. The decoding function populates the array with the sampled
36
+ token ids and returns it.
37
+
38
+ For decoder-only architectures such as a Prefix Language Model, `inputs` can
39
+ be a concatenated sequence of "inputs" and "targets" tokens ids.
40
+
41
+ `tokens_to_logits` is a callable that takes in a batch of token ids and the
42
+ current autoregressive cache, performs the forward pass and returns the
43
+ resulting logits and an updated cache. Note that for incremental
44
+ decoding, this function operates with a single token, i.e., the length dimension
45
+ is assumed to be 1.
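To make this contract concrete, here is a toy greedy decoder in plain Python that follows the shape described above. The argument names (`inputs`, `cache`, `tokens_to_logits`, `eos_id`, `num_decodes`) mirror the description, but this is an illustrative sketch, not the exact `DecodeFnCallable` signature from `t5x/decoding.py`:

```python
import math

def greedy_decode(inputs, cache, tokens_to_logits, eos_id=1, num_decodes=1,
                  **kwargs):
    """Toy greedy decoder with a DecodeFnCallable-like shape (illustrative).

    inputs: [batch][max_decode_len] int token ids (all zeros for a standard
        encoder-decoder model, per the description above).
    tokens_to_logits: maps (current token ids [batch], cache) to
        (logits [batch][vocab], updated cache), one position at a time.
    Returns (sequences [batch][num_decodes][len], log_probs [batch][num_decodes]).
    """
    batch, max_len = len(inputs), len(inputs[0])
    sequences = [[0] * max_len for _ in range(batch)]
    log_probs = [0.0] * batch
    cur = [0] * batch  # decoding starts from token id 0 (BOS)
    for t in range(max_len):
        logits, cache = tokens_to_logits(cur, cache)
        for b in range(batch):
            # Log-softmax normalizer, then greedy argmax over the vocabulary.
            z = math.log(sum(math.exp(x) for x in logits[b]))
            best = max(range(len(logits[b])), key=lambda i: logits[b][i])
            log_probs[b] += logits[b][best] - z
            sequences[b][t] = best
        cur = [sequences[b][t] for b in range(batch)]
    # Greedy search is deterministic, so the num_decodes "samples" coincide.
    return ([[list(sequences[b])] * num_decodes for b in range(batch)],
            [[log_probs[b]] * num_decodes for b in range(batch)])
```

The populated array and the per-sequence log probabilities are returned together, matching the description of how `inputs` is filled in and sorted outputs are produced.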
46
+
47
+ `DecodeFnCallable` is designed to be as general as possible. This results in
48
+ some of the arguments being somewhat generic for a specialized decoding
49
+ algorithm. For example, `num_decodes` refers to the number of decoded samples to
50
+ be returned. In the case of beam search, `num_decodes` corresponds to what is
51
+ commonly known as `beam_size`, with returned sequences sorted by the beam
52
+ scores. For temperature sampling, we perform `num_decodes` *independent*
53
+ sampling procedures with different random seeds and sort them by the log
54
+ probability of the generated sequences.
55
+
56
+ For custom decoding functions, there might be additional arguments. To support
57
+ these, we provide `**kwargs`.
58
+
59
+ Another usage of `**kwargs` is calling `decoding_fn` multiple times without
60
+ recompiling the model. This pattern is used in
61
+ [Prediction Service](https://github.com/google-research/t5x/blob/main/t5x/google/prediction_service/README.md).
62
+ For a compiled model, different values of `alpha` can be passed e.g.,
63
+ `decoder_params = {"alpha": 0.7}` where `decoder_params` is the argument to
64
+ `predict_batch_with_aux`. It is unpacked and passed to `beam_search` function.
65
+ Note that the Prediction Service uses
66
+ [`predict_batch_with_aux`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch_with_aux%5Cb),
67
+ which is one of the two public methods. This method is useful if auxiliary
68
+ outputs (e.g., scores of the predictions) are to be returned. The other method
69
+ is
70
+ [`predict_batch`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch%5Cb),
71
+ which simply returns the predictions.
72
+
73
+ ### Beam search
74
+
75
+ The following lines can be added to a gin file in order to use
76
+ [beam search](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=881;rcl=446762159)
77
+ as a decoding function for an encoder-decoder model.
78
+
79
+ ```gin
80
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
81
+ models.EncoderDecoderModel.decode_fn = @decoding.beam_search
82
+ decoding.beam_search.alpha = 0.6
83
+ ```
84
+
85
+ Note that we skip the gin boilerplate code such as gin dynamic registration.
86
+ Please refer to [T5X Gin Primer](gin.md) for more details.
87
+
88
+ The beam search behavior is controlled by the arguments passed to `beam_search`.
89
+ We provide details for a few of them below.
90
+
91
+ #### `num_decodes`
92
+
93
+ If `num_decodes` is configured with `gin.register`, it is overridden by the
94
+ value explicitly passed by the caller e.g.,
95
+ `models.EncoderDecoderModel.predict_batch_with_aux`. This is because the
96
+ information about `num_decodes` is needed to prepare the encoder inputs and
97
+ outputs, which are expanded `num_decodes` times in the batch dimension.
98
+
99
+ We recommend that `num_decodes` be specified *only* in
100
+ `models.EncoderDecoderModel.predict_batch_with_aux`.
101
+
102
+ #### `alpha`
103
+
104
+ This is the brevity penalty introduced in
105
+ [Wu et al. 2016](https://arxiv.org/abs/1609.08144) to penalize short sequences.
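Concretely, the length normalization in Wu et al. 2016 divides each candidate's log-probability by a penalty of the form

$$\mathrm{lp}(Y) = \frac{(5 + |Y|)^\alpha}{(5 + 1)^\alpha}$$

where $$|Y|$$ is the candidate length; `alpha = 0` disables normalization, while larger values increasingly favor longer sequences.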
106
+
107
+ #### `max_decode_len`
108
+
109
+ For evaluation, we typically don't want to truncate the examples by a specified
110
+ sequence length. Therefore, we dynamically obtain the length information from
111
+ the batch of examples. The default behavior of `seqio.Evaluator` is to use the
112
+ maximum length of a task, but this can be overridden.
113
+
114
+ Since the length information is provided dynamically, we don't set
115
+ `max_decode_len` in gin. Instead we pass the relevant `inputs` array to
116
+ `beam_search` whose length is the dynamically determined maximum length.
117
+
118
+ If `max_decode_len` is explicitly specified via gin, this will override the
119
+ implicitly determined length information unless it is passed by
120
+ `predict_batch_with_aux`.
121
+
122
+ ### Temperature sampling
123
+
124
+ [Temperature sampling](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=37;rcl=446762159)
125
+ can be used for multiple decoding strategies. The following lines configure
126
+ temperature sampling as a `decode_fn`.
127
+
128
+ ```gin
129
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
130
+ models.EncoderDecoderModel.decode_fn = @decoding.temperature_sample
131
+ decoding.temperature_sample:
132
+ temperature = 0.5
133
+ topk = 20
134
+ ```
135
+
136
+ A similar specification can be used for other model types by replacing
137
+ `models.EncoderDecoderModel` with the relevant model class, e.g.
138
+ `models.PrefixLanguageModel`.
139
+
140
+ The sampling behavior is controlled by the arguments passed to
141
+ `temperature_sample`. We provide details for a few of them below.
142
+
143
+ #### `temperature`
144
+
145
+ A probabilistic model outputs a probability distribution over a pre-defined
146
+ alphabet. For example, a language model outputs *logits*, which are unnormalized
147
+ log-probability values for each item in the vocabulary. We use a language model as a
148
+ running example. A sampling process involves *sampling* from the predicted
149
+ distribution one item at a time conditioned on the previously generated items
150
+ until a given number of items are generated or a sentinel token that represents
151
+ the end of sequence is generated.
152
+
153
+ Temperature modifies the unnormalized probability distribution at each step. For
154
+ each item $$i$$ in the vocabulary, its probability predicted by the model is
155
+ given by
156
+
157
+ $$p_i \propto \exp\left(\frac{x_i}{T} \right)$$
158
+
159
+ where $$T$$ is the temperature and $$x_i$$ is the logits value corresponding to
160
+ item $$i$$. As $$T \to 0$$, the distribution puts all probability mass to the
161
+ item with the highest probability. In other words, the sampling process becomes
162
+ a greedy search.
163
+
164
+ In the other extreme, as $$T \to \infty$$, the predicted distribution becomes
165
+ uniform.
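As a minimal sketch of this formula (plain Python for illustration only; the actual T5X implementation in `decoding.py` operates on JAX arrays):

```python
import math
import random

def sample_with_temperature(logits, temperature=1.0, rng=random.Random(0)):
    """Draw one item id from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    r = rng.random() * sum(weights)
    acc = 0.0
    for i, w in enumerate(weights):
        acc += w
        if r < acc:
            return i
    return len(logits) - 1  # fallback for floating-point edge cases
```

As `temperature` approaches 0, the scaled logits concentrate all mass on the argmax (greedy search); as it grows, the draw approaches uniform sampling, matching the two limits above.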
166
+
167
+ #### `topk`
168
+
169
+ By specifying a strictly positive integer value for `topk`, the sampling process
170
+ in each step is limited to the `k` items with highest probabilities. `topk` also
171
+ uses `temperature` to modify the logits corresponding to the top `k` items.
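A minimal sketch of the top-k restriction (illustrative only, not the T5X implementation; ties at the threshold are all kept):

```python
def top_k_filter(logits, k):
    """Keep the k largest logits; mask the rest to -inf so they can't be sampled."""
    kth = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth else float("-inf") for x in logits]
```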
172
+
173
+ #### `topp`
174
+
175
+ By specifying a positive float value for `topp`, the sampling process is
176
+ limited to a subset of the vocabulary $$V^{(p)} \subset V$$, which is defined by
177
+ the smallest set such that
178
+
179
+ $$\sum_{i \in V^{(p)}} p_i \ge p$$
180
+
181
+ where $$p_i$$ is the conditional distribution at each time step for item $$i$$.
182
+ This is called "Nucleus sampling", which was introduced by
183
+ [Holtzman et al. ICLR 2020](https://openreview.net/forum?id=rygGQyrFvH).
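The nucleus restriction can be sketched as follows (illustrative only; it takes an already-normalized distribution and renormalizes over the kept set):

```python
def top_p_filter(probs, p):
    """Keep the smallest set of highest-probability items with mass >= p,
    then renormalize that set into a new distribution."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep, mass = [], 0.0
    for i in order:
        keep.append(i)
        mass += probs[i]
        if mass >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}
```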
184
+
185
+ IMPORTANT: Only one of `topk` or `topp` can be used.
186
+
187
+ ## Option 2: subclassing a model class
188
+
189
+ If `DecodeFnCallable` is not flexible enough for your custom decoding function,
190
+ you can subclass the model class and override the `predict_batch_with_aux` method.
191
+ While the model class can be any instance of
192
+ [`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb),
193
+ we recommend that you subclass the existing models such as
194
+ [`EncoderDecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbEncoderDecoderModel%5Cb)
195
+ and only override the `predict_batch_with_aux` method.
196
+
197
+ The `predict_batch_with_aux` method also has a required call signature, but it
198
+ is significantly more flexible. It should return a tuple of the predicted
199
+ sequence array and auxiliary outputs such as scores.
t5x-main/docs/usage/eval.md ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Evaluating a Model

## Introduction

This page outlines the steps to evaluate a model with T5X on downstream tasks
defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md).

Refer to this tutorial when you have an existing model that you want to
evaluate. If you would like to fine-tune your model before evaluation, please
refer to the [fine-tuning](finetune.md) tutorial. You can also run evals as
part of your fine-tuning run.

## Overview

Evaluating a model with T5X consists of the following steps:

1.  Choose the model to evaluate.
1.  Choose the SeqIO Task/Mixture to evaluate the model on.
1.  Write a Gin file that configures the model, SeqIO Task/Mixture and other
    details of your eval run.
1.  Launch your experiment locally or on XManager.
1.  Monitor your experiment and parse metrics.

These steps are explained in detail in the following sections. An example run
that evaluates a fine-tuned T5-1.1-Small checkpoint on the
[(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
is also showcased.

## Step 1: Choose a model

To evaluate a model, you need a Gin config file that defines the model params,
and the model checkpoint to load from. For this example, a T5-1.1-Small model
fine-tuned on the
[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
SeqIO Task will be used:

+   Model checkpoint -
    [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
+   Model Gin file -
    [`t5x/configs/models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).

If you would like to fine-tune your model before evaluation, please follow the
[fine-tuning](finetune.md) tutorial, and continue to Step 2. A list of all
available pre-trained models (with model checkpoints and Gin config files) is
available in the [Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation.

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be
performed on the data before querying the model, the postprocessing logic to be
performed on model outputs, and the metrics to be computed given the
postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks
and enables fine-tuning a model on multiple Tasks simultaneously.

Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
[SuperGLUE](https://super.gluebenchmark.com/),
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
implemented as SeqIO Tasks/Mixtures and can be used directly. These
Tasks/Mixtures are defined in
[`t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) and
[`t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).

For the example run, you will evaluate the model on the Natural Questions
benchmark, which has been implemented as the `natural_questions_open` Task in
[`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021).
Here's an example of a single row of preprocessed data from this Task:

```python
{
    'inputs_pretokenized': 'nq question: what was the main motive of salt march',
    'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1],
    'targets_pretokenized': 'challenge to British authority',
    'targets': [1921, 12, 2390, 5015, 1],
    'answers': ['challenge to British authority']
}
```

## Step 3: Write a Gin Config

After choosing the model and SeqIO Task/Mixture for your run, the next step is
to configure your run using Gin. If you're not familiar with Gin, reading the
[T5X Gin Primer](gin.md) is recommended.

T5X provides a Gin file that configures the T5X eval job (located at
[`t5x/configs/runs/eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin)),
and expects a few params from you. These params can be specified in a separate
Gin file, or via command-line flags. Following are the required params:

+   `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
    For the example run, set this to
    `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
+   `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run eval
    on (from Step 2). For the example run, set this to
    `'natural_questions_open'`.
+   `EVAL_OUTPUT_DIR`: A path to write eval outputs to. When launching using
    XManager, this path is automatically set and can be accessed from the
    XManager Artifacts page. When running locally using Blaze, you can
    explicitly pass a directory using a flag. Launch commands are provided in
    the next step.

In addition to the above params, you will need to import
[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) and the
Gin file for the model, which for the example run is
[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).

```gin
include 'runs/eval.gin'
include 'models/t5_1_1_small.gin'
```

Note that the `include` statements use relative paths in this example. You will
pass an appropriate `gin_search_paths` flag to locate these files when launching
your run. Absolute paths to Gin files can also be used, e.g.

```gin
include 't5x/configs/runs/eval.gin'
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
```

You will also need to import the Python module(s) that register the SeqIO Tasks
and Mixtures used in your run. For the example run, we add `import
google_research.t5_closed_book_qa.t5_cbqa.tasks`, since that is where
`natural_questions_open` is registered.

If you choose a module that is not included as a dependency in the T5X trainer
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=76;rcl=398627055), or if you
have defined your gin config file in a location other than the
[T5X config directory](https://github.com/google-research/t5x/blob/main/t5x/configs/), you will
need to follow the instructions in the
[Advanced Topics section](#custom-t5x-binaries) to link in the custom gin file
and/or task definition.

Note that for many common Tasks/Mixtures, the necessary modules are already
included. It is also possible to skip writing a Gin file and instead pass the
params as flags when launching the eval job (see instructions in Step 4).
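
For instance, a flag-only launch could look like the sketch below. It follows the
`--gin.<PARAM>` override pattern used in Step 4; the exact flag values are
illustrative, and the module registering the Task must still be linked into the
binary:

```sh
python -m t5x.eval \
  --gin_file=t5x/configs/runs/eval.gin \
  --gin_file=t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin \
  --gin.CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000\" \
  --gin.MIXTURE_OR_TASK_NAME=\"natural_questions_open\" \
  --gin.EVAL_OUTPUT_DIR=\"/tmp/model-eval\" \
  --alsologtostderr
```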

Finally, your Gin file should look like this:

```gin
include 't5x/configs/runs/eval.gin'
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'

# Register necessary SeqIO Tasks/Mixtures.
import google_research.t5_closed_book_qa.t5_cbqa.tasks

CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
MIXTURE_OR_TASK_NAME = 'natural_questions_open'
```

See
[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin)
for this example.

In this example, we run the evaluation on one checkpoint. It is common to
evaluate with multiple checkpoints, and T5X provides an easy way to do so
*without* having to recompile the model graph for each checkpoint: simply add
`utils.RestoreCheckpointConfig.mode = 'all'` to your gin file. Our
`t5x/configs/runs/eval.gin` uses the `'specific'` mode by default.
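
Concretely (a sketch; the binding name follows the `RestoreCheckpointConfig`
usage in `eval.gin`), the override is a one-line addition to your Gin file:

```gin
# Evaluate every checkpoint found under CHECKPOINT_PATH instead of a single one.
utils.RestoreCheckpointConfig.mode = 'all'
```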

## Step 4: Launch your experiment

To launch your experiment locally (for debugging only; larger checkpoints may
cause issues), run the following on the command line:

```sh
EVAL_OUTPUT_DIR="/tmp/model-eval/"
python -m t5x.eval \
  --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin \
  --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
  --alsologtostderr
```

Note that relative paths can be used to locate the gin files. For that, multiple
comma-separated paths can be passed to the `gin_search_paths` flag, and these
paths should contain all Gin files used or included in your experiment.

You can have a look inside
[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) to see
other useful parameters that can be passed in, including the dataset split,
batch size, and random seed.

## Step 5: Monitor your experiment and parse metrics

After evaluation has completed, you can parse metrics into CSV format using the
following script:

```sh
EVAL_OUTPUT_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise
VAL_DIR="$EVAL_OUTPUT_DIR/inference_eval"
python -m t5.scripts.parse_tb \
  --summary_dir="$VAL_DIR" \
  --seqio_summaries \
  --out_file="$VAL_DIR/results.csv" \
  --alsologtostderr
```
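
The resulting `results.csv` can then be inspected with standard tooling. The
sketch below uses only the Python standard library; the column names in the
sample data are hypothetical, since the actual columns depend on the metrics
your Task defines:

```python
import csv
import io

# Hypothetical sample of the parsed metrics CSV; the real column names
# depend on the SeqIO metrics defined by the Task being evaluated.
SAMPLE = """\
step,em,f1
1110000,28.4,36.1
"""

def read_results(fileobj):
    """Parse a metrics CSV into a list of {column: value} dicts."""
    return list(csv.DictReader(fileobj))

rows = read_results(io.StringIO(SAMPLE))
for row in rows:
    print(row["step"], row["em"], row["f1"])
```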

## Next Steps

Now that you have successfully evaluated a model on the Natural Questions
benchmark, here are some topics you might want to explore next:

+   [Running inference on a model.](infer.md)
+   [Fine-tuning a model.](finetune.md)
+   [Training a model from scratch.](pretrain.md)

We also touch upon a few advanced topics related to evaluations below that might
be useful, especially when customizing your eval job.

## Advanced Topics

### Defining a custom SeqIO Task/Mixture to evaluate on {.no-toc}

Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).

### Defining a custom metric to evaluate {.no-toc}

The best way to define a custom metric is to define a new SeqIO Task/Mixture
that contains this custom metric. Please refer to the SeqIO documentation on
[custom metrics](https://github.com/google/seqio/blob/main/README.md#metrics).
t5x-main/docs/usage/finetune.md ADDED
@@ -0,0 +1,286 @@
# Fine Tuning a Model

## Introduction

This page outlines the steps to fine-tune an existing pre-trained model with T5X
on common downstream tasks defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). This is one of
the simplest and most common use cases of T5X. If you're new to T5X, this
tutorial is the recommended starting point.

## Overview

Fine-tuning a model with T5X consists of the following steps:

1.  Choose the pre-trained model to fine-tune.
2.  Choose the SeqIO Task/Mixture to fine-tune the model on.
3.  Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture
    and other details of your fine-tuning run.
4.  Launch your experiment locally or on XManager.
5.  Monitor your experiment and parse metrics.

These steps are explained in detail in the following sections. An example run
that fine-tunes a T5-small checkpoint on the WMT14 English to German translation
benchmark is also showcased.

## Step 1: Choose a pre-trained model

To use a pre-trained model, you need a Gin config file that defines the model
params, and the model checkpoint to load from. For your convenience, TensorFlow
checkpoints and Gin configs for common T5 pre-trained models have been made
available for use in T5X. A list of all the available pre-trained models (with
model checkpoints and Gin config files) is available in the
[Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation.

For the example run, you will use the T5 1.1 Small model. The Gin file for this
model is located at
[`/t5x/examples/t5/t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin),
and the checkpoint is located at
[`gs://t5-data/pretrained_models/t5x/t5_1_1_small`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small).

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be
performed on the data before querying the model, the postprocessing logic to be
performed on model outputs, and the metrics to be computed given the
postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks
and enables fine-tuning a model on multiple Tasks simultaneously.

### Standard Tasks

Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
[SuperGLUE](https://super.gluebenchmark.com/),
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
implemented as SeqIO Tasks/Mixtures and can be used directly. These
Tasks/Mixtures are defined in
[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
and
[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).

For the example run, you will fine-tune the model on the WMT14 English to German
translation benchmark, which has been implemented as the
[`wmt_t2t_ende_v003`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py;l=209;rcl=417815592)
Task.

### Custom Tasks

It is also possible to define your own custom task. See the
[SeqIO documentation](https://github.com/google/seqio/blob/main/README.md) for how to do this. As a note, Tasks
defined using the
[old T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/dataset_providers.py)
may also be used by T5X. If using a custom Task, you will need to follow the
instructions in the [Advanced Topics section](#custom-t5x-binaries) at the end
of this tutorial to make sure the module containing your task is included.

When defining a custom task, you have the option to cache it on disk before
fine-tuning. The instructions for this are
[here](https://github.com/google/seqio/blob/main/README.md#optional-offline-caching). Caching may improve
performance for tasks with expensive pre-processing. By default, T5X expects
tasks to be cached. To fine-tune on a task that has not been cached, set
`--gin.USE_CACHED_TASKS=False`.

## Step 3: Write a Gin Config

After choosing the pre-trained model and SeqIO Task/Mixture for your run, the
next step is to configure your run using Gin. If you're not familiar with Gin,
reading the [T5X Gin Primer](gin.md) is recommended.

T5X provides a Gin file that configures the T5X trainer for fine-tuning (located
at
[`t5x/configs/runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)),
and expects a few params from you. These params can be specified in a separate
Gin file, or via command-line flags. Following are the required params:

+   `INITIAL_CHECKPOINT_PATH`: This is the path to the pre-trained checkpoint
    (from Step 1). For the example run, set this to
    `'gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000'`.
+   `TRAIN_STEPS`: Number of fine-tuning steps. This includes the number of
    steps that the model was pre-trained for, so make sure to add the step
    number from the `INITIAL_CHECKPOINT_PATH`. For the example run, to fine-tune
    for `20_000` steps, set this to `1_020_000`, since the initial checkpoint is
    the `1_000_000`th step.
+   `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from
    Step 2). For the example run, set this to `'wmt_t2t_ende_v003'`.
+   `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int
    length for that feature. After preprocessing, features are truncated to the
    provided value. For the example run, set this to `{'inputs': 256, 'targets':
    256}`.
+   `MODEL_DIR`: A path to write fine-tuned checkpoints to. When launching using
    XManager, this path is automatically set and can be accessed from the
    XManager Artifacts page. When running locally using Blaze, you can
    explicitly pass a directory using a flag. Launch commands are provided in
    the next step.
+   `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
    using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should
    be set to `pretraining batch_size` * `pretrained target_token_length`. For
    T5 and T5.1.1: `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
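
As a sanity check on the two derived values above (a standalone sketch, not T5X
code), the step count and the T5 1.1 loss-normalizing factor work out as
follows:

```python
# TRAIN_STEPS includes the pre-training steps baked into the checkpoint name.
pretrain_steps = 1_000_000   # step number in INITIAL_CHECKPOINT_PATH
finetune_steps = 20_000
train_steps = pretrain_steps + finetune_steps
print(train_steps)  # 1020000

# LOSS_NORMALIZING_FACTOR for T5 / T5.1.1:
# pre-training batch size * pre-trained target token length.
loss_normalizing_factor = 2048 * 114
print(loss_normalizing_factor)  # 233472
```

The second value matches the `LOSS_NORMALIZING_FACTOR = 233472` setting in the
example Gin file below.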

In addition to the above params, you will need to include
[`finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)
and the Gin file for the pre-trained model, which for the example run is
[`t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin).

```gin
include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'
```

You will also need to import the Python module(s) that register SeqIO Tasks and
Mixtures used in your run. For the example run, we add `import t5.data.tasks`
since it is where `wmt_t2t_ende_v003` is registered.

Finally, your Gin file should look like this:

```gin
include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.tasks

MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
TRAIN_STEPS = 1_020_000  # 1000000 pre-trained steps + 20000 fine-tuning steps.
DROPOUT_RATE = 0.0
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
LOSS_NORMALIZING_FACTOR = 233472
```

See
[`t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
for this example.

## Step 4: Launch your experiment

To launch your experiment locally (for debugging only; larger checkpoints may
cause issues), run the following on the command line:

```sh
MODEL_DIR="/tmp/finetune-model/"
python -m t5x.train \
  --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --alsologtostderr
```

Note that multiple comma-separated paths can be passed to the `gin_search_paths`
flag, and these paths should contain all Gin files used or included in your
experiment.

After fine-tuning has completed, you can parse metrics into CSV format using the
following script:

```sh
MODEL_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise
VAL_DIR="$MODEL_DIR/inference_eval"
python -m t5.scripts.parse_tb \
  --summary_dir="$VAL_DIR" \
  --seqio_summaries \
  --out_file="$VAL_DIR/results.csv" \
  --alsologtostderr
```

### Metric Explanations

By default, T5X logs many metrics to TensorBoard. Many of these seem similar but
have important distinctions.

The first two graphs you will see are the `accuracy` and `cross_ent_loss`
graphs. These are the *token-level teacher-forced* accuracy and cross entropy
loss, respectively. Each of these graphs can have multiple curves on them. The
first curve is the `train` curve. This is calculated as a running sum that is
then normalized over the whole training set. The second class of curves has the
form `training_eval/${task_name}`. These curves are created by running a subset
(controlled by the `eval_steps` parameter of the main train function) of the
validation split of `${task_name}` through the model and calculating these
metrics using teacher-forcing. These graphs can commonly be used to find
"failure to learn" cases and as a warning sign of overfitting, but they are
often not the final metrics one would report.
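
To make the distinction concrete, token-level teacher-forced accuracy can be
sketched as follows (plain Python, not the T5X implementation; padding positions
are assumed to carry a weight of 0):

```python
def teacher_forced_accuracy(predictions, targets, weights):
    """Fraction of non-padding target tokens predicted exactly.

    predictions, targets: sequences of token ids.
    weights: 1.0 for real tokens, 0.0 for padding.
    """
    correct = sum(w for p, t, w in zip(predictions, targets, weights) if p == t)
    total = sum(weights)
    return correct / total if total else 0.0

# The model got 3 of the 4 real tokens right; the padded position is ignored.
acc = teacher_forced_accuracy(
    predictions=[5, 8, 2, 9, 0],
    targets=[5, 8, 2, 7, 0],
    weights=[1.0, 1.0, 1.0, 1.0, 0.0],
)
print(acc)  # 0.75
```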

The second set of graphs are the ones under the collapsible `eval` section in
TensorBoard. These graphs are created based on the `metric_fns` defined in the
SeqIO task. The curves on these graphs have the form
`inference_eval/${task_name}`. Values are calculated by running the whole
validation split through the model in inference mode, commonly auto-regressive
decoding or output scoring. Most likely, these are the metrics that will be
reported.

More information about the configuration of the datasets used for these
different metrics can be found [here](#train-train-eval-and-infer-eval).

In summary, the metric you actually care about most likely lives under the
`eval` tab, rather than in the `accuracy` graph.

## Next Steps

Now that you have successfully fine-tuned a pre-trained model on WMT, here are
some topics you might want to explore next:

+   [Evaluating a fine-tuned model.](eval.md)
+   [Running inference on a fine-tuned model.](infer.md)
+   [Training a model from scratch.](pretrain.md)

We also touch upon a few advanced topics related to fine-tuning below that might
be useful, especially when customizing your fine-tuning job.

## Advanced Topics

### `train`, `train_eval` and `infer_eval` {.no-toc}

A
[`DatasetConfig`](https://github.com/google-research/t5x/blob/main/t5x/utils.py?l=113&rcl=375475889)
object is used to configure loading SeqIO Tasks/Mixtures for training and eval.
If you take a closer look at
[`runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin),
you will see that there are three `DatasetConfig` objects defined and passed to
the train function: `train_dataset_cfg`, `train_eval_dataset_cfg`, and
`infer_eval_dataset_cfg`. Here's a brief description of these configs:

+   `train`: This configures the Task/Mixture that the model will be fine-tuned
    on.
+   `train_eval`: This configures the Task/Mixture that is used to compute
    training metrics on the eval split, e.g. perplexity. These metrics are
    defined in the
    [`Model`](https://github.com/google-research/t5x/blob/main/t5x/models.py;l=257-267;rcl=394045248)
    class and the eval fn is located
    [here](https://github.com/google-research/t5x/blob/main/t5x/trainer.py;l=257;rcl=398487394).
+   `infer_eval`: This configures the Task/Mixture that is used to compute
    metrics on inferred model outputs (e.g., comparing decoded model outputs and
    targets). These metrics are defined in the SeqIO Task/Mixture and the eval
    fn is located
    [here](https://github.com/google/seqio/tree/main/seqio/evaluation.py?l=423&rcl=373643592).

### Using separate SeqIO Tasks/Mixtures for fine-tuning and eval {.no-toc}

Commonly, the same SeqIO Task/Mixture is used for training and eval. It is set
by the `MIXTURE_OR_TASK_NAME` macro in your fine-tune Gin file from Step 3
above, and is passed to `train_dataset_cfg`, `train_eval_dataset_cfg`, and
`infer_eval_dataset_cfg`. The `train` split is used for training and the
`validation` split is used for evals. However, you can override these params in
your fine-tune Gin config. For example, if you want to fine-tune on all GLUE
tasks but evaluate only on the GLUE STS benchmark, you can override the SeqIO
Task/Mixture used for `infer_eval` in your fine-tune Gin file as follows:

```gin
include 'runs/finetune.gin'
include 'models/t5_small.gin'

MIXTURE_OR_TASK_NAME = 'glue_v002_proportional'
MIXTURE_OR_TASK_MODULE = 't5.data.tasks'
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 84}
TRAIN_STEPS = 1_500_000  # includes 1_000_000 pretrain steps
INITIAL_CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000'
infer_eval/utils.DatasetConfig.mixture_or_task_name = 'glue_stsb_v002'
```

Other params in `finetune.gin` can be overridden in the same way.

### Defining a custom SeqIO Task/Mixture to fine-tune on {.no-toc}

Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
t5x-main/docs/usage/gin.md ADDED
@@ -0,0 +1,395 @@
# Gin Primer

[Gin](https://github.com/google/gin-config/blob/main/README.md) is a lightweight configuration framework for Python,
based on dependency injection. While T5X does not employ gin in its core
libraries, it is used to configure runs of the `train`, `eval`, and `infer`
scripts. This usage is a bit different (and more limited) than how gin is
typically applied, so this primer should be useful even for those who may be
familiar with gin from other libraries (e.g., T5 or Mesh TensorFlow).

Nevertheless, you may still find it helpful to refer to the
[gin documentation](https://github.com/google/gin-config/blob/main/README.md) for more background.

[TOC]

## Gin in T5X Scripts

Rather than plumbing run arguments and hyperparameters through a limited set
of command-line flags or a flat configuration schema, T5X's gin integration
allows you to parameterize the top-level run functions (`train`, `evaluate`, and
`infer`) as well as any object or function that is passed to them. This enables
a vast amount of flexibility over your runs without needing to modify any code
within the core T5X library.

For example, you can implement a Python class in your own codebase (e.g., a
custom model or trainer) and use gin to pass an instance of it to the T5X XM
launcher without having to fork any code. Previously, you needed to implement
every experimental idea in the core library (no matter how widely used it would
be) and add a ConfigDict flag to enable/disable it, resulting in significant
code debt over time.

On the other hand, gin can sometimes be too powerful, allowing users the ability
to bind arguments throughout a codebase, which makes it difficult or impossible
to update "private" internal interfaces. However, by limiting configurability to
a single top-level function and its arguments, we can better control the
configurable surface to public interfaces and user-owned code, and also avoid
unintended side effects.

### An Example

Let's look at the `evaluate` call signature from
[eval.py](https://github.com/google-research/t5x/blob/main/t5x/eval.py) as an example:

```py
def evaluate(*,
             model: models.BaseModel,
             dataset_cfg: utils.DatasetConfig,
             restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
             partitioner: partitioning.BasePartitioner,
             output_dir: str):
  """Evaluation function.

  Args:
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: The partitioner for the model parameters and data across
      devices.
    output_dir: Path to directory to write temporary files and final results.
  """
  ...
```

In the binary, the user-provided gin configuration file will be parsed. It
specifies which values should be bound to the `evaluate` arguments, after which
we can directly call the fully-bound function without any arguments. Basically,
we are creating a custom closure of `evaluate` (a la `functools.partial`), but
specifying the arguments via gin instead of Python.
+
71
+ Furthermore, this ability to bind custom arguments is recursive. Not only can we
72
+ bind the arguments of `evaluate`, but we can also bind the constructor and
73
+ method arguments of the instance of `models.BaseModel` that we pass to
74
+ `evaluate`.
75
+
76
+ Let's now look at an example of a gin configuration for parameterizing
77
+ `evaluate`, specifically evaluating a
78
+ [T5 model fine-tuned for closed book question answering](http://goo.gle/t5-cbqa)
79
+ on [Natural Questions Open](https://ai.google.com/research/NaturalQuestions):
80
+
81
+ ```py
82
+ from __gin__ import dynamic_registration
83
+
84
+ import __main__ as eval_script
85
+ from t5x import models
86
+ from t5x import partitioning
87
+ from t5x import utils
88
+
89
+ MODEL = %gin.REQUIRED
90
+
91
+ eval_script.evaluate:
92
+ model = %MODEL
93
+ output_dir = '/tmp/t5x_eval'
94
+ dataset_cfg = @utils.DatasetConfig()
95
+ partitioner = @partitioning.PjitPartitioner()
96
+ restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
97
+
98
+ # Load model with overrides.
99
+ include 'models/t5_large.gin'
100
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
101
+
102
+ utils.DatasetConfig:
103
+ mixture_or_task_name = 'natural_questions_open'
104
+ split = 'test'
105
+ task_feature_lengths = None
106
+ batch_size = 32
107
+ shuffle = False
108
+ seed = 0
109
+ use_cached = False
110
+ pack = False
111
+ use_custom_packing_ops = False
112
+ module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
113
+
114
+ partitioning.PjitPartitioner:
115
+ num_partitions = 1
116
+
117
+ utils.RestoreCheckpointConfig:
118
+ mode = 'specific'
119
+ path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
120
+ assignment_map = None
121
+ strict = True
122
+ dtype = None
123
+ ```
124
+
125
+ Let's go through this block-by-block.
126
+
127
+ ```py
128
+ from __gin__ import dynamic_registration
129
+ ```
130
+
131
+ The first line imports a new gin feature (see cl/372624800 for more details) to
132
+ allow us to register functions and objects for configuration from within the gin
133
+ file itself without having to modify or decorate functions from the imported
134
+ packages.
135
+
136
+ ```py
137
+ import __main__ as eval_script
138
+ from t5x import models
139
+ from t5x import utils
140
+ ```
141
+
142
+ The second block imports the modules containing the components we plan to
143
+ configure in this file and is required for dynamic registration. Note that only
144
+ those functions and objects that we specify below will actually be configured,
145
+ not everything in the module. Also, as is the case in Python, the binary module
146
+ is referred as `__main__`, although we rename it to `eval_script` for clarity in
147
+ the rest of the config.
148
+
149
+ ```py
150
+ MODEL = %gin.REQUIRED
151
+ ```
152
+
153
+ The third block creates a
154
+ [gin macro](https://github.com/google/gin-config/tree/master/docs/index.md#gin-macros)
155
+ (essentially a lazy reference) and for now sets it to refer to the special macro
156
+ `gin.REQUIRED`, which will cause a failure during parsing of the configuration
157
+ if not updated via a later assignment in the config file or command-line flags
158
+ (see [below](#command-line-usage)).
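As a hypothetical illustration (not part of the config above), a macro is assigned once and then referenced elsewhere with a `%` prefix:

```py
MODEL_DIR = '/tmp/my_run'  # satisfies a `MODEL_DIR = %gin.REQUIRED` declaration
eval_script.evaluate.output_dir = %MODEL_DIR
```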
159
+
160
+ ```py
161
+ eval_script.evaluate:
162
+ model = %MODEL
163
+ output_dir = '/tmp/t5x_eval'
164
+ dataset_cfg = @utils.DatasetConfig()
165
+ partitioner = @partitioning.PjitPartitioner()
166
+ restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
167
+ ```
168
+
169
+ The fourth block specifies the binding for the `evaluate` function. For `model`,
170
+ we pass the value of the `MODEL` macro (to be defined later). For `output_dir`
171
+ we pass a string path. For `dataset_cfg`, `restore_checkpoint_cfg`, and
172
+ `partitioner`, we pass instantiations of `DatasetConfig` and
173
+ `RestoreCheckpointConfig`, both defined in
174
+ [utils.py](https://github.com/google-research/t5x/blob/main/t5x/utils.py), and of
175
+ `PjitPartitioner`, defined in
176
+ [partitioning.py](https://github.com/google-research/t5x/blob/main/t5x/partitioning.py).
+ The '@' prefix tells gin that the following is a configured
177
+ function or class, and the '()' suffix signifies that it should be called (in
178
+ the case of a class, this means calling its constructor). If we wanted to pass in
179
+ the closure (or a partially bound) function instead of its return value, we
180
+ would leave off the parentheses.
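To make the distinction concrete, here is a hypothetical pair of bindings (not from the config above):

```py
# Passes the instance returned by calling the constructor:
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()

# Would instead pass the (partially bound) DatasetConfig class itself,
# leaving the caller to invoke it:
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig
```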
181
+
182
+ The remainder of the file deals with defining the `MODEL` macro and fully
183
+ binding these constructors.
184
+
185
+ ```py
186
+ # Load model with overrides.
187
+ include 't5x/examples/t5/t5_1_1/large.gin'
188
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
189
+ ```
190
+
191
+ Although we could define `MODEL = model.EncoderDecoderModel()` here, we prefer
192
+ to create a separate gin file that defines it. This makes it easier to reuse
193
+ parts of the common configurations. All of the bindings in the newly included
194
+ file are read and override any conflicting ones defined so far in this file.
195
+ It's equivalent to copying and pasting the contents of the included file at this
196
+ location in the config. If you want to see how the model itself is instantiated,
197
+ you can refer to
198
+ [t5_1_1/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin)
199
+ (which simply overrides a few values from
200
+ [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin)).
201
+
202
+ The final line of this block shows an example of how you can modify the default
203
+ arguments of the `EncoderDecoderModel` instance referenced by `%MODEL`, in this
204
+ case changing the default beam size it will use during prediction. Notice that
205
+ since we are only binding one argument here, we choose to write it on a single
206
+ line instead of using the block binding syntax used elsewhere in the file.
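The two syntaxes are interchangeable; the single-line binding above could equally be written in block form:

```py
models.EncoderDecoderModel.predict_batch_with_aux:
  num_decodes = 1
```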
207
+
208
+ ```py
209
+ utils.DatasetConfig:
210
+ mixture_or_task_name = 'natural_questions_open'
211
+ split = 'test'
212
+ task_feature_lengths = None
213
+ batch_size = 32
214
+ shuffle = False
215
+ seed = 0
216
+ use_cached = False
217
+ pack = False
218
+ use_custom_packing_ops = False
219
+ module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
220
+
221
+ partitioning.PjitPartitioner:
222
+ num_partitions = 1
223
+
224
+ utils.RestoreCheckpointConfig:
225
+ mode = 'specific'
226
+ path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
227
+ assignment_map = None
228
+ strict = True
229
+ dtype = None
230
+ ```
231
+
232
+ The last 3 blocks are fairly straightforward. They are effectively setting the
233
+ attributes of these dataclasses by binding values to their constructors that
234
+ will be used when they are instantiated and passed to `evaluate`, as specified
235
+ in the fourth block.
236
+
237
+ ### Scoping
238
+
239
+ The above example lacks one key component of gin:
240
+ [scopes](https://github.com/google/gin-config/blob/main/README.md#4-configuring-the-same-function-in-different-ways-scopes).
241
+
242
+ What happens if you need to use a class or function multiple times but with
243
+ different bound values?
244
+
245
+ A clear example of this is in the top-level `train` function (in
246
+ [train.py](https://github.com/google-research/t5x/blob/main/t5x/train.py)). The call signature
247
+ includes 3 different instances of `utils.DatasetConfig`: one for the train
248
+ dataset, one for the "train-eval" dataset (used for evaluation with teacher
249
+ forcing), and one for the "infer-eval" dataset (used for evaluation with
250
+ inference/decoding).
251
+
252
+ The solution is to prefix each instance with a unique identifier both when
253
+ specifying where it is to be passed to `train` and when binding its arguments.
254
+ For example, the gin file might look like the following (skipping the irrelevant
255
+ bits):
256
+
257
+ ```py
258
+ ...
259
+
260
+ train_script.train:
261
+ train_dataset_cfg = @train/utils.DatasetConfig()
262
+ train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
263
+ infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
264
+ ...
265
+
266
+ train/utils.DatasetConfig:
267
+ mixture_or_task_name = 'train_mixture'
268
+ split = 'train'
269
+ ...
270
+
271
+ train_eval/utils.DatasetConfig:
272
+ mixture_or_task_name = 'eval_mixture'
273
+ split = 'validation'
274
+ ...
275
+
276
+ infer_eval/utils.DatasetConfig:
277
+ mixture_or_task_name = 'eval_mixture'
278
+ split = 'test'
279
+ ...
280
+ ```
281
+
282
+ We have therefore configured 3 different scoped-versions of
283
+ `utils.DatasetConfig` producing 3 separate instances that are passed to `train`.
284
+
285
+ Note that these three scopes will all inherit from the base scope, so if you
286
+ want to set a shared binding, you may directly configure `utils.DatasetConfig`
287
+ without a scope prefix.
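For example (an illustrative fragment), a binding on the unscoped name is shared by all scoped instances, while a scoped binding applies to just one:

```py
# Inherited by train/, train_eval/, and infer_eval/:
utils.DatasetConfig.seed = 42

# Applies only to the train/ instance:
train/utils.DatasetConfig.shuffle = True
```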
288
+
289
+ ## Command-Line Usage
290
+
291
+ So now that you have a gin config, how do you pass it to the script? There are
292
+ two ways: gin files and override flags.
293
+
294
+ 1. **Gin Files** You have already seen an example of a gin file above. You can
295
+ specify the gin file(s) to use in your script via the `--gin_file` flag. If
296
+ you want to load multiple gin files, you can set the flag multiple times and
297
+ the files will be loaded in order, with the second potentially overriding
298
+ the first when there are conflicts. It is possible to supply a
299
+ comma-separated list of search prefixes via `--gin_search_paths` and then
300
+ only specify the relative path to the `--gin_file` flags. However, we
301
+ strongly recommend against using `--gin_search_paths`. Using absolute paths
302
+ via the `--gin_file` flags will reduce sources of ambiguity and improve the
303
+ consistency of your scripts.
304
+
305
+ 1. **Override Flags** Gin flags allow for more fine-grained overrides of any
306
+ configurable aspect of your run. These flags follow the single-line binding
307
+ format from the above example with the addition of a `--gin.` prefix. For
308
+ example, if you want to override the dataset shuffling, you can set
309
+ `--gin.utils.DatasetConfig.shuffle=False`. In the train setting where there
310
+ are multiple datasets, you must supply the appropriate scope, e.g.,
311
+ `--gin.train/utils.DatasetConfig.shuffle=False`. These bindings are
312
+ processed in order *after* the gin files are loaded, and therefore overwrite
313
+ any previously assigned value in the gin files.
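A rough mental model of this precedence (a sketch only, not gin's actual implementation) is a dictionary of bindings updated in order, with files first and flags last:

```python
file_bindings = {'DatasetConfig.shuffle': True, 'DatasetConfig.seed': 0}
flag_bindings = {'DatasetConfig.shuffle': False}  # from --gin. overrides

merged = {}
for source in (file_bindings, flag_bindings):  # later sources win
    merged.update(source)

assert merged == {'DatasetConfig.shuffle': False, 'DatasetConfig.seed': 0}
```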
314
+
315
+ **Note:** when supplying a string, dict, list, or tuple value via a flag, you
316
+ must put it in quotes. In the case of strings, it requires escaped quotes
317
+ (`\"<string>\"`). For example: `--gin.utils.DatasetConfig.split=\"validation\"`,
318
+ `--gin.utils.DatasetConfig.task_feature_lengths="{'inputs': 512, 'targets':
319
+ 84}"`, and `--gin.dense.MlpBlock.activations="('dense', 'gelu')"`
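To see what actually reaches the binary after the shell strips its layer of quoting, you can `echo` the flag (illustrative):

```sh
echo --gin.utils.DatasetConfig.split=\"validation\"
# prints: --gin.utils.DatasetConfig.split="validation"
```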
320
+
321
+ ### An Example
322
+
323
+ An example where you may need multiple files is with the `train` script.
324
+
325
+ You can first specify which model you want to train by supplying a gin file
326
+ containing its definition, for example:
327
+ [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin).
328
+
329
+ You may then specify a run config that supplies some of the common defaults. For
330
+ example, if you are doing pretraining you can use
331
+ [runs/pretrain.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin),
332
+ and if you are doing finetuning, you can use
333
+ [runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin).
334
+
335
+ We can apply these two files with the following command:
336
+
337
+ ```sh
338
+ python -m t5x.train_unfragmented \
339
+ --gin_file=t5x/examples/t5/t5_1_1/small.gin \
340
+ --gin_file=t5x/configs/runs/finetune.gin \
341
+ --logtostderr
342
+ ```
343
+
344
+ However, running this command will give you an error like the following:
345
+
346
+ ```sh
347
+ ValueError: MODEL_DIR/macro.value set to `%gin.REQUIRED` but not subsequently overridden.
348
+ ```
349
+
350
+ This is because the config still includes some `gin.REQUIRED` macros that you'll
351
+ need to override with the details of your run. At the top of
352
+ [runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)
353
+ you'll see the list of required overrides, which we will populate for finetuning
354
+ on WMT in the updated launch command here:
355
+
356
+ ```sh
357
+ python -m t5x.train_unfragmented \
358
+ --gin_file=t5x/examples/t5/t5_1_1/small.gin \
359
+ --gin_file=t5x/configs/runs/finetune.gin \
360
+ --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
361
+ --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
362
+ --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" \
363
+ --gin.TRAIN_STEPS=1_020_000 \
364
+ --gin.MODEL_DIR=\"/tmp/t5_1_1_base_finetune_gin\" \
365
+ --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000\" \
366
+ --logtostderr
367
+ ```
368
+
369
+ Note you may still override any registered bindings. For example, to disable
370
+ inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`.
371
+
372
+ ### A File-only Example
373
+
374
+ At the beginning of the primer, we saw a fully-specified run config. We can do
375
+ something similar with the previous example to create a self-contained run
376
+ configuration.
377
+ [t5_1_1/examples/small_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
378
+ is just such an example that allows you to exactly duplicate the previous launch
379
+ command simply by calling:
380
+
381
+ ```sh
382
+ python -m t5x.train_unfragmented \
383
+ --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
384
+ --gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \
385
+ --logtostderr
386
+ ```
387
+
388
+ ## Logging
389
+
390
+ After your gin files and flag overrides are parsed, the complete configuration
391
+ will be logged to INFO, written to `config.gin` in the output directory, and
392
+ added to a TensorBoard summary.
393
+
394
+ It is highly recommended that you review this generated config to ensure that
395
+ your overrides are working as expected.
t5x-main/docs/usage/gpu-usage.md ADDED
@@ -0,0 +1,87 @@
1
+ # GPU Scripts
2
+
3
+ # Warning!
4
+ An updated version of T5x with optimized GPU performance (18-80% perf gains!) and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x).
5
+ -----
6
+ **NVIDIA no longer recommends using this repository and won't be updating it further.**
7
+ -----
8
+
9
+ The [t5x/contrib/gpu](../../t5x/contrib/gpu) directory contains scripts optimized for GPU usage.
10
+
11
+ Install the Pile dependencies with `pip install -r pile_requirements.txt`.
12
+
13
+ ## Building the container
14
+ The Dockerfile in `t5x/contrib/gpu` will build a container with all GPU/Pile dependencies. Build it with `t5x/contrib/gpu/docker/build.sh <name>`.
15
+
16
+ ## Running interactively
17
+ Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example:
18
+
19
+ `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir`
20
+
21
+ ## Downloading The Pile
22
+ Run `download_the_pile.py` to download the Pile. It downloads to the directory set in the `TFDS_DATA_DIR` environment variable. Afterwards, set `TFDS_DATA_DIR` to the same directory when running your scripts.
23
+
24
+ ## Single Node runs
25
+ Pretraining and finetuning can be done with `singlenode_*.sh`. These scripts build a T5X model with the Adam optimizer and relevant parameters, and support multi-GPU training on a single host.
26
+
27
+ ## Multi Node runs
28
+ For a SLURM+pyxis cluster, the `example*.sub` files provide example Slurm submit files (edit them with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (at the cost of some throughput).
29
+
30
+ ## Convergence
31
+ For our Pile convergence runs, we used a global batch size (GBS) of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel (TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes.
32
+
33
+ | size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log |
34
+ | ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- |
35
+ | small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) |
36
+ | large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |
37
+ | xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |
38
+ | xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)|
39
+
40
+ Note: Convergence (as shown in the logs) was not necessarily achieved with the hardware topology listed, but the listed topology is tested. Estimated walltime assumes full throughput (seq/sec) is sustained continuously. In practice, there are compilation overheads at the beginning of each run/restart (in cluster settings), plus checkpointing overheads (if any).
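As a quick sanity check (not part of the training scripts), the GBS formula reproduces the table's batch sizes:

```python
def global_batch_size(num_gpus, bs_per_gpu, tensor_parallel):
    # GBS = #GPUs * BS/GPU / TP, as defined above.
    return num_gpus * bs_per_gpu // tensor_parallel

assert global_batch_size(8, 256, 1) == 2048    # small
assert global_batch_size(64, 32, 1) == 2048    # large
assert global_batch_size(256, 8, 1) == 2048    # xl
assert global_batch_size(512, 36, 8) == 2304   # xxl
```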
41
+
42
+ (More perf improvements coming soon!)
43
+
44
+ Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory.
45
+
46
+ ## Pretraining run commands
47
+
48
+ ### Singlenode
49
+ small:
50
+
51
+ `t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}`
52
+
53
+ Finetuning:
54
+ MNLI v2:
55
+ `t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}`
56
+
57
+
58
+ ### Multinode
59
+ Arguments are as such:
60
+
61
+ `sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
62
+
63
+ small:
64
+
65
+ `sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1`
66
+
67
+ large:
68
+
69
+ `sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1`
70
+
71
+ xl:
72
+
73
+ `sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1`
74
+
75
+ Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from.
76
+
77
+ MNLI v2:
78
+
79
+ `sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
80
+
81
+ SQuAD v1.1
82
+
83
+ `sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
84
+
85
+ On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision.
86
+
87
+ WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up.
t5x-main/docs/usage/index.rst ADDED
@@ -0,0 +1,16 @@
1
+ T5X Usage Guides
2
+ ================
3
+
4
+ .. toctree::
5
+ :maxdepth: 2
6
+
7
+ pretrain.md
8
+ finetune.md
9
+ eval.md
10
+ infer.md
11
+ auxiliary.md
12
+ decoding.md
13
+ metrics.md
14
+ partitioning.md
15
+ gin.md
16
+
t5x-main/docs/usage/infer-files.md ADDED
@@ -0,0 +1,217 @@
1
+ # Running inference on a Model
2
+
3
+
4
+ ## Introduction
5
+
6
+ This page outlines the steps to run inference on a model with T5X on files
7
+ containing
8
+ [TensorFlow Examples](https://www.tensorflow.org/api_docs/python/tf/train/Example).
9
+
10
+ ## Overview
11
+
12
+ Running inference on a model with T5X using TF Example files consists of the
13
+ following steps:
14
+
15
+ 1. Choose the model to run inference on.
16
+ 1. Choose the TF Example files to run inference on.
17
+ 1. Write a Gin file that configures the model, file source and other details of
18
+ your inference run.
19
+ 1. Launch your experiment locally or on XManager.
20
+ 1. Monitor your experiment and access predictions.
21
+
22
+ These steps are explained in detail in the following sections. An example run
23
+ that runs inference on a fine-tuned T5-1.1-Small checkpoint on `tfrecord` files
24
+ containing the
25
+ [(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
26
+ is also showcased.
27
+
28
+ ## Step 1: Choose a model
29
+
30
+ To run inference on a model, you need a Gin config file that defines the model
31
+ params, and the model checkpoint to load from. For this example, a T5-1.1-Small
32
+ model fine-tuned on the
33
+ [`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
34
+ SeqIO Task will be used:
35
+
36
+ + Model checkpoint -
37
+ [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
38
+ + Model Gin file -
39
+ [`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
40
+
41
+ If you would like to fine-tune your model before inference, please follow the
42
+ [fine-tuning](finetune.md) tutorial, and continue to Step 2.
43
+
44
+ ## Step 2: Choose a TF Example file source
45
+
46
+ T5X supports running inference on `tfrecord`, `recordio` and `sstable` files
47
+ containing TF Examples. For the example run, you will run inference on
48
+ `tfrecord` files containing the `'natural_questions_open'` dataset located here:
49
+ `/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*`.
50
+ Here's an example of a single row of data from this file:
52
+
53
+ ```json
54
+ { # (tensorflow.Example) size=101B
55
+ features: { # (tensorflow.Features) size=99B
56
+ feature: { # (tensorflow.Features.FeatureEntry) size=27B
57
+ key: "answer" # size=6
58
+ value: { # (tensorflow.Feature) size=17B
59
+ bytes_list: { # (tensorflow.BytesList) size=15B
60
+ value: [ "Jason Flemyng" ] # size=13
61
+ } # features.feature[0].value.bytes_list
62
+ } # features.feature[0].value
63
+ } # features.feature[0]
64
+ feature: { # (tensorflow.Features.FeatureEntry) size=68B
65
+ key: "question" # size=8
66
+ value: { # (tensorflow.Feature) size=56B
67
+ bytes_list: { # (tensorflow.BytesList) size=54B
68
+ value: [ "who played hyde in league of extraordinary gentlemen" ] # size=52
69
+ } # features.feature[1].value.bytes_list
70
+ } # features.feature[1].value
71
+ } # features.feature[1]
72
+ } # features
73
+ }
74
+ ```
75
+
76
+ ## Step 3: Write a Gin Config
77
+
78
+ After choosing the model and file source for your run, the next step is to
79
+ configure your run using Gin. If you're not familiar with Gin, reading the
80
+ [T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures
81
+ the T5X inference job (located at
82
+ [`t5x/configs/runs/infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin))
83
+ to run inference on TF Example files, and expects a few params from you. These
84
+ params can be specified in a separate Gin file, or via commandline flags.
85
+ Following are the required params:
86
+
87
+ + `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
88
+ For the example run, set this to
89
+ `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
90
+ + `TF_EXAMPLE_FILE_PATHS`: This is a list of paths or glob patterns to read TF
91
+ Examples from. For the example run, set this to
92
+ `['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*']`.
93
+ + `TF_EXAMPLE_FILE_TYPE`: This is the TF Example file format. Currently
94
+ supported file formats are `tfrecord`, `recordio` and `sstable`. For the
95
+ example run, set this to `'tfrecord'`.
96
+ + `FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int length
97
+ for that feature. The TF Example features are truncated to the provided
98
+ value. For the example run, set this to `{'inputs': 38, 'targets': 18}`,
99
+ which is the maximum token length for the test set.
100
+ + `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching
101
+ using XManager, this path is automatically set and can be accessed from the
102
+ XManager Artifacts page. When running locally using Blaze, you can
103
+ explicitly pass a directory using a flag. Launch commands are provided in
104
+ the next step.
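The truncation behavior of `FEATURE_LENGTHS` can be sketched as follows (an illustrative sketch only, not T5X's actual preprocessing code):

```python
def truncate_features(example, feature_lengths):
    # Clip each tokenized feature to its configured maximum length;
    # features without a configured length pass through unchanged.
    return {key: values[:feature_lengths[key]] if key in feature_lengths
            else values
            for key, values in example.items()}

example = {'inputs': list(range(50)), 'targets': list(range(10))}
out = truncate_features(example, {'inputs': 38, 'targets': 18})
assert len(out['inputs']) == 38   # truncated to the maximum
assert len(out['targets']) == 10  # shorter than the maximum, unchanged
```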
105
+
106
+ In addition to the above params, you may also need to override the
107
+ `create_task_from_tfexample_file.inputs_key` param based on the data format (it
108
+ is set to `'inputs'` by default). For the example run, the `'question'` key
109
+ contains the input (see Step 2), so add the following to your Gin config:
110
+
111
+ ```gin
112
+ create_task_from_tfexample_file.inputs_key = 'question'
113
+ ```
114
+
115
+ Additionally, you will need to import the
116
+ [`infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin)
117
+ and the Gin file for the model, which for the example run is
118
+ [`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
119
+
120
+ ```gin
121
+ include 'runs/infer_from_tfexample_file.gin'
122
+ include 'models/t5_1_1_small.gin'
123
+ ```
124
+
125
+ Note that the `include` statements use relative paths in this example. You will
126
+ pass an appropriate `gin_search_paths` flag to locate these files when launching
127
+ your run. Absolute paths to Gin files can also be used, e.g.
128
+
129
+ ```gin
130
+ include 't5x/configs/runs/infer_from_tfexample_file.gin'
131
+ include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
132
+ ```
133
+
134
+ Finally, your Gin file should look like this:
135
+
136
+ ```gin
137
+ include 'runs/infer_from_tfexample_file.gin'
138
+ include 'models/t5_1_1_small.gin'
139
+
140
+ CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
141
+ TF_EXAMPLE_FILE_PATHS = ['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*']
142
+ TF_EXAMPLE_FILE_TYPE = 'tfrecord'
143
+ FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}
144
+ create_task_from_tfexample_file.inputs_key = 'question'
145
+ ```
146
+
147
+ See
148
+ [`t5x/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin)
149
+ for this example. Make sure that your Gin file is linked as a data dependency to
150
+ the T5X inference
151
+ [binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your
152
+ Gin file is not included, see the
153
+ [Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for
154
+ instructions to add it, or skip writing a Gin file and pass the above params as
155
+ flags when launching the inference job (see instructions in Step 4).
156
+
157
+ ## Step 4: Launch your experiment
158
+
159
+ To launch your experiment locally (for debugging only; larger checkpoints may
160
+ cause issues), run the following on the command line:
161
+
162
+ ```sh
163
+ INFER_OUTPUT_DIR="/tmp/model-infer/"
164
+ python -m t5x.infer_unfragmented \
165
+ --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin \
166
+ --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
167
+ --alsologtostderr
168
+ ```
169
+
170
+ Note that multiple comma-separated paths can be passed to the `gin_search_paths`
171
+ flag, and these paths should contain all Gin files used or included in your
172
+ experiment.
173
+
174
+
175
+ After inference has completed, you can view predictions in the `jsonl` files in
176
+ the output dir. JSON data is written in chunks and combined at the end of the
177
+ inference run. Refer to [Sharding](#sharding) and
178
+ [Checkpointing](#checkpointing) sections for more details.
179
+
180
+ ## Next Steps
181
+
182
+ Now that you have successfully run inference on a model, here are some topics
183
+ you might want to explore next:
184
+
185
+ + [Fine-tuning a model.](finetune.md)
186
+ + [Evaluating a model.](eval.md)
187
+ + [Training a model from scratch.](pretrain.md)
188
+
189
+ We also touch upon a few advanced topics related to inference below that might
190
+ be useful, especially when customizing your inference job.
191
+
192
+ ## Advanced Topics
193
+
194
+ ### Dataset Sharding {#sharding .no-toc}
195
+
196
+ You can run inference in parallel across multiple TPU slices by setting the
197
+ `num_shards` flag when running using XManager. When `num_shards > 1`, the
198
+ dataset is interleaved among the shards and the predictions are combined in the
199
+ end; hence the order of examples in the data source and the predictions in the
200
+ output json files will not match (order is guaranteed to match for `num_shards =
201
+ 1` or the number of input file shards).
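The interleaving can be sketched as follows (illustrative; the exact shard assignment in T5X may differ):

```python
def shard(examples, shard_id, num_shards):
    # Shard i takes examples i, i + num_shards, i + 2 * num_shards, ...
    return examples[shard_id::num_shards]

data = list(range(10))
shards = [shard(data, i, 3) for i in range(3)]
assert shards[0] == [0, 3, 6, 9]
assert sorted(x for s in shards for x in s) == data  # every example covered once
```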
202
+
203
+ ### Dataset Checkpointing {#checkpointing .no-toc}
204
+
205
+ You can control dataset checkpointing frequency by overriding the
206
+ `infer.checkpoint_period` in
207
+ [runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin),
208
+ which is set to `100` by default. This means that the dataset is checkpointed
209
+ after running inferences on `checkpoint_period` batches (batches, not examples;
210
+ you can control batch size by overriding `utils.DatasetConfig.batch_size` in
211
+ [runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), it
212
+ is set to `32` by default).
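With the defaults above, that works out to one dataset checkpoint every 3,200 examples:

```python
checkpoint_period = 100  # batches, default in runs/infer.gin
batch_size = 32          # default utils.DatasetConfig.batch_size
examples_per_checkpoint = checkpoint_period * batch_size
assert examples_per_checkpoint == 3200
```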
213
+
214
+
215
+ ### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc}
216
+
217
+ Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
t5x-main/docs/usage/infer-seqio.md ADDED
@@ -0,0 +1,241 @@
1
+ # Running inference on a Model
2
+
3
+
4
+ ## Introduction
5
+
6
+ This page outlines the steps to run inference on a model with T5X on Tasks/Mixtures
7
+ defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md).
8
+
9
+ ## Overview
10
+
11
+ Running inference on a model with T5X using SeqIO Task/Mixtures consists of the
12
+ following steps:
13
+
14
+ 1. Choose the model to run inference on.
15
+ 1. Choose the SeqIO Task/Mixture to run inference on.
16
+ 1. Write a Gin file that configures the model, SeqIO Task/Mixture and other
17
+ details of your inference run.
18
+ 1. Launch your experiment locally or on XManager.
19
+ 1. Monitor your experiment and access predictions.
20
+
21
+ These steps are explained in detail in the following sections. An example run
22
+ that runs inference on a fine-tuned T5-1.1-Small checkpoint on the
23
+ [(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
24
+ is also showcased.
25
+
26
+ ## Step 1: Choose a model
27
+
28
+ To run inference on a model, you need a Gin config file that defines the model
29
+ params, and the model checkpoint to load from. For this example, a T5-1.1-Small
30
+ model fine-tuned on the
31
+ [`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
32
+ SeqIO Task will be used:
33
+
34
+ + Model checkpoint -
35
+ [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
36
+ + Model Gin file -
37
+ [`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
38
+
39
+ If you would like to fine-tune your model before inference, please follow the
40
+ [fine-tuning](finetune.md) tutorial, and continue to Step 2.
41
+
42
+ ## Step 2: Choose a SeqIO Task/Mixture
43
+
44
+ A SeqIO Task encapsulates the data source, the preprocessing logic to be
45
+ performed on the data before querying the model, the postprocessing logic to be
46
+ performed on model outputs, and the metrics to be computed given the
47
+ postprocessed outputs and targets (for inference, post-processing and metrics
48
+ are irrelevant). A SeqIO Mixture denotes a collection of Tasks and enables
49
+ fine-tuning a model on multiple Tasks.
50
+
51
+ Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
52
+ [SuperGLUE](https://super.gluebenchmark.com/),
53
+ [WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
54
+ [SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
55
+ [CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
56
+ implemented as SeqIO Tasks/Mixtures and can be used directly. These
57
+ Tasks/Mixtures are defined in
58
+ [`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
59
+ and
60
+ [`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).
61
+
62
+ For the example run, you will run inference on the (Open Domain) Natural
63
+ Questions benchmark, which has been implemented as the `natural_questions_open`
64
+ Task in
65
+ [`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021).
66
+ Here's an example of a single row of preprocessed data from this Task:
67
+
68
+ ```json
69
+ {
70
+ 'inputs_pretokenized': 'nq question: what was the main motive of salt march',
71
+ 'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1]
72
+ 'targets_pretokenized': 'challenge to British authority',
73
+ 'targets': [1921, 12, 2390, 5015, 1],
74
+ 'answers': ['challenge to British authority']
75
+ }
76
+ ```
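
The shape of that preprocessed row can be sketched in plain Python. This is a
toy illustration only: the real Task applies its preprocessors inside SeqIO and
tokenizes with a SentencePiece vocabulary, which is mocked here with a made-up
word-to-id table (the ids will not match the ones above).

```python
# Toy sketch of SeqIO-style preprocessing for an open-domain QA example.
# The vocabulary below is fabricated; real ids come from SentencePiece.

TOY_VOCAB = {}  # assigns a fresh id to each new whitespace-separated word
EOS_ID = 1      # SeqIO appends an EOS token to tokenized features

def tokenize(text):
    """Map each word to a stable toy id and append EOS."""
    ids = []
    for word in text.split():
        if word not in TOY_VOCAB:
            TOY_VOCAB[word] = len(TOY_VOCAB) + 2  # 0 = pad, 1 = eos reserved
        ids.append(TOY_VOCAB[word])
    return ids + [EOS_ID]

def preprocess(question, answer):
    """Build a feature dict shaped like the example row above."""
    inputs_pretokenized = 'nq question: ' + question
    return {
        'inputs_pretokenized': inputs_pretokenized,
        'inputs': tokenize(inputs_pretokenized),
        'targets_pretokenized': answer,
        'targets': tokenize(answer),
        'answers': [answer],
    }

example = preprocess('what was the main motive of salt march',
                     'challenge to British authority')
```

The key takeaway is the feature layout: `inputs`/`targets` are token id lists
ending in EOS, and the `*_pretokenized` fields keep the raw strings around for
inspection and postprocessing.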
## Step 3: Write a Gin Config

After choosing the model and SeqIO Task/Mixture for your run, the next step is
to configure your run using Gin. If you're not familiar with Gin, reading the
[T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures
the T5X inference job (located at
[`runs/infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin)) to
run inference on SeqIO Tasks/Mixtures, and expects a few params from you. These
params can be specified in a separate Gin file, or via commandline flags. The
following params are required:

+   `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
    For the example run, set this to
    `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
+   `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run
    inference on (from Step 2). For the example run, set this to
    `'natural_questions_open'`.
+   `MIXTURE_OR_TASK_MODULE`: This is the Python module that contains the SeqIO
    Task or Mixture. For the example run, set this to
    `'google_research.t5_closed_book_qa.t5_cbqa.tasks'`. Note that this module
    must be included as a dependency in the T5X inference
    [binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055).
    Most common Task modules, including `t5_closed_book_qa`, are already
    included. If your module is not included, see the
    [Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial
    for instructions to add it.
+   `TASK_FEATURE_LENGTHS`: This is a dict mapping each feature key to its
    maximum length. After preprocessing, features are truncated to the provided
    values. For the example run, set this to `{'inputs': 38, 'targets': 18}`,
    which is the maximum token length for the test set.
+   `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching
    using XManager, this path is automatically set and can be accessed from the
    XManager Artifacts page. When running locally using Blaze, you can
    explicitly pass a directory using a flag. Launch commands are provided in
    the next step.

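
As an illustration of what `TASK_FEATURE_LENGTHS` does, the sketch below
truncates tokenized features to the configured maximums. It is a simplification
of what happens inside the SeqIO pipeline (which also handles EOS and padding):

```python
# Simplified sketch: cap each tokenized feature at its configured maximum.
# SeqIO applies this truncation (plus EOS/padding handling) internally.

TASK_FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}

def truncate_features(example, lengths):
    """Truncate every feature that has a configured maximum length."""
    return {key: (value[:lengths[key]] if key in lengths else value)
            for key, value in example.items()}

example = {
    'inputs': list(range(50)),   # 50 tokens: longer than the 38 allowed
    'targets': list(range(10)),  # already within the 18-token limit
}
trimmed = truncate_features(example, TASK_FEATURE_LENGTHS)
```

Features longer than the configured maximum are cut down to it; shorter
features pass through unchanged, which is why picking the dataset's maximum
token lengths (38/18 here) avoids losing any tokens.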
In addition to the above params, you will need to import
[`infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin) and the
Gin file for the model, which for the example run is
[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin):

```gin
include 'runs/infer.gin'
include 'models/t5_1_1_small.gin'
```

Note that the `include` statements use relative paths in this example. You will
pass an appropriate `gin_search_paths` flag to locate these files when launching
your run. Absolute paths to Gin files can also be used, e.g.

```gin
include 't5x/configs/runs/infer.gin'
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
```

Finally, your Gin file should look like this:

```gin
include 'runs/infer.gin'
include 'models/t5_1_1_small.gin'

CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
MIXTURE_OR_TASK_NAME = 'natural_questions_open'
MIXTURE_OR_TASK_MODULE = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
TASK_FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}
```
See
[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin)
for this example. Make sure that your Gin file is linked as a data dependency to
the T5X inference
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your
Gin file is not included, see the
[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for
instructions to add it, or skip writing a Gin file and pass the above params as
flags when launching the inference job (see instructions in Step 4).

## Step 4: Launch your experiment

To launch your experiment locally (for debugging only; larger checkpoints may
cause issues), run the following on commandline:

```sh
INFER_OUTPUT_DIR="/tmp/model-infer/"
python -m t5x.infer \
  --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin \
  --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
  --alsologtostderr
```

Note that multiple comma-separated paths can be passed to the `gin_search_paths`
flag, and these paths should contain all Gin files used or included in your
experiment.
## Step 5: Monitor your experiment and parse results

After inference has completed, you can view predictions in the `jsonl` files in
the output dir. JSON data is written in chunks and combined at the end of the
inference run. Refer to the [Sharding](#sharding) and
[Checkpointing](#checkpointing) sections below for more details.
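
The combined `jsonl` files can be parsed with the standard library, one JSON
object per line. The file name used in the demo below is fabricated; read
whatever files appear in your `INFER_OUTPUT_DIR`:

```python
import json
import os
import tempfile

def read_predictions(path):
    """Yield one prediction dict per non-empty line of a .jsonl file."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)

# Demo with a small fabricated file; real files are written by the
# inference job into INFER_OUTPUT_DIR.
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, 'predictions.jsonl')
with open(path, 'w') as f:
    f.write(json.dumps({'inputs': 'nq question: ...',
                        'prediction': 'challenge to British authority'}) + '\n')

rows = list(read_predictions(path))
```

Streaming line by line (rather than loading the whole file) keeps memory flat
even for large prediction sets.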
## Next Steps

Now that you have successfully run inference on a model, here are some topics
you might want to explore next:

+   [Fine-tuning a model.](finetune)
+   [Evaluating a model.](eval)
+   [Training a model from scratch.](pretrain)

We also touch upon a few advanced topics related to inference below that might
be useful, especially when customizing your inference job.
## Advanced Topics

### Dataset Sharding {#sharding .no-toc}

You can run inference in parallel across multiple TPU slices by setting the
`num_shards` flag when running using XManager. When `num_shards > 1`, the
dataset is interleaved among the shards and the predictions are combined at the
end; hence the order of examples in the data source and the predictions in the
output json files will not match (order is guaranteed to match only when
`num_shards = 1` or when `num_shards` equals the number of input file shards).
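
A minimal sketch of why the output order changes: assume examples are dealt
round-robin to shards (an assumption about the interleaving scheme, for
illustration), each shard writes predictions in its own order, and the
per-shard outputs are concatenated at the end:

```python
def shard_round_robin(examples, num_shards):
    """Interleave examples among shards: shard i gets examples i, i+n, i+2n, ..."""
    return [examples[i::num_shards] for i in range(num_shards)]

examples = list(range(10))
shards = shard_round_robin(examples, num_shards=3)
# Each shard runs inference independently; concatenating the per-shard
# outputs preserves every example but not the original dataset order.
combined = [x for shard in shards for x in shard]
```

With `num_shards = 1` there is a single shard, so concatenation trivially
preserves the source order, which matches the guarantee stated above.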
### Dataset Checkpointing {#checkpointing .no-toc}

You can control dataset checkpointing frequency by overriding
`infer.checkpoint_period` in
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin),
which is set to `100` by default. This means that the dataset is checkpointed
after running inference on every `checkpoint_period` batches (batches, not
examples; you can control the batch size by overriding
`utils.DatasetConfig.batch_size` in
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin),
which is set to `32` by default).
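
With the defaults above, the interval between dataset checkpoints works out as
follows:

```python
checkpoint_period = 100  # batches between dataset checkpoints (default)
batch_size = 32          # examples per batch (default)

# Checkpoints happen every checkpoint_period batches, so with the defaults
# inference progress is saved once every 3200 examples.
examples_per_checkpoint = checkpoint_period * batch_size
```

Lowering `checkpoint_period` saves progress more often at the cost of extra
checkpointing overhead.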
### Changing Length and Decoding Strategy {#decoding-strategies .no-toc}

By default, T5X runs inference with a greedy (argmax) decoding strategy, always
picking the most likely next token. To use random sampling instead, you can
override any of the following parameters in your Gin config:

```gin
decoding.temperature_sample.temperature = 1.0
decoding.temperature_sample.topk = 1
decoding.temperature_sample.topp = 0.0
```
You can also control the number of tokens that get generated by specifying:

```gin
decoding.temperature_sample.max_decode_steps = 50
```
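
For intuition, here is a self-contained sketch of temperature sampling with
`topk`/`topp` filtering over one step's logits. It mirrors the shape of the
parameters above but is not T5X's implementation, which operates on batched JAX
arrays inside the decode loop:

```python
import math
import random

def sample_token(logits, temperature=1.0, topk=0, topp=0.0, rng=None):
    """Sample one token id from logits with top-k / top-p (nucleus) filtering.

    topk=0 / topp=0.0 disable the respective filter; topk=1 is greedy argmax.
    """
    rng = rng or random.Random(0)
    # Rank candidate token ids by logit, highest first.
    ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if topk:
        ids = ids[:topk]
    # Softmax over the surviving candidates, scaled by temperature.
    scaled = [logits[i] / max(temperature, 1e-6) for i in ids]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    if topp:
        # Keep the smallest prefix of candidates whose mass reaches topp.
        kept, mass = [], 0.0
        for i, p in zip(ids, probs):
            kept.append((i, p))
            mass += p
            if mass >= topp:
                break
        norm = sum(p for _, p in kept)
        ids = [i for i, _ in kept]
        probs = [p / norm for _, p in kept]
    return rng.choices(ids, weights=probs, k=1)[0]

logits = [0.1, 2.5, 0.3, 1.7]
greedy = sample_token(logits, topk=1)  # topk=1 always picks the argmax
```

This makes the defaults above concrete: `topk = 1` collapses sampling to
argmax, while raising `topk`, setting `topp`, or increasing `temperature`
broadens the distribution the next token is drawn from.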
More detailed documentation on defining a decoding strategy can be found
[here](https://github.com/google-research/t5x/blob/main/docs/usage.md/decoding).

### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc}

Refer to the
[SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).