Upload 369 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- t5x-main/.github/workflows/build.yaml +39 -0
- t5x-main/CONTRIBUTING.md +1 -0
- t5x-main/LICENSE +202 -0
- t5x-main/README.md +525 -0
- t5x-main/docs/_static/t5x_theme.css +23 -0
- t5x-main/docs/_templates/autosummary/t5x_module.rst +23 -0
- t5x-main/docs/api_reference/index.rst +100 -0
- t5x-main/docs/api_reference/t5x.adafactor.rst +7 -0
- t5x-main/docs/api_reference/t5x.binary_search.rst +7 -0
- t5x-main/docs/api_reference/t5x.checkpoint_importer.rst +7 -0
- t5x-main/docs/api_reference/t5x.checkpoint_utils.rst +7 -0
- t5x-main/docs/api_reference/t5x.checkpoints.rst +7 -0
- t5x-main/docs/api_reference/t5x.config_utils.rst +7 -0
- t5x-main/docs/api_reference/t5x.decoding.rst +7 -0
- t5x-main/docs/api_reference/t5x.eval.rst +7 -0
- t5x-main/docs/api_reference/t5x.gin_utils.rst +7 -0
- t5x-main/docs/api_reference/t5x.infer.rst +7 -0
- t5x-main/docs/api_reference/t5x.interactive_model.rst +7 -0
- t5x-main/docs/api_reference/t5x.losses.rst +7 -0
- t5x-main/docs/api_reference/t5x.main.rst +7 -0
- t5x-main/docs/api_reference/t5x.metrics.rst +7 -0
- t5x-main/docs/api_reference/t5x.models.rst +7 -0
- t5x-main/docs/api_reference/t5x.optimizers.rst +7 -0
- t5x-main/docs/api_reference/t5x.partitioning.rst +7 -0
- t5x-main/docs/api_reference/t5x.state_utils.rst +7 -0
- t5x-main/docs/api_reference/t5x.test_utils.rst +7 -0
- t5x-main/docs/api_reference/t5x.train.rst +7 -0
- t5x-main/docs/api_reference/t5x.train_state.rst +7 -0
- t5x-main/docs/api_reference/t5x.trainer.rst +7 -0
- t5x-main/docs/api_reference/t5x.utils.rst +7 -0
- t5x-main/docs/conf.py +132 -0
- t5x-main/docs/conf_sphinx_patch.py +202 -0
- t5x-main/docs/contributions.md +64 -0
- t5x-main/docs/index.md +65 -0
- t5x-main/docs/index.rst +24 -0
- t5x-main/docs/models.md +318 -0
- t5x-main/docs/overview.md +2 -0
- t5x-main/docs/requirements.txt +8 -0
- t5x-main/docs/t5x.png +3 -0
- t5x-main/docs/tutorials.md +51 -0
- t5x-main/docs/usage/auxiliary.md +204 -0
- t5x-main/docs/usage/decoding.md +199 -0
- t5x-main/docs/usage/eval.md +226 -0
- t5x-main/docs/usage/finetune.md +286 -0
- t5x-main/docs/usage/gin.md +395 -0
- t5x-main/docs/usage/gpu-usage.md +87 -0
- t5x-main/docs/usage/index.rst +16 -0
- t5x-main/docs/usage/infer-files.md +217 -0
- t5x-main/docs/usage/infer-seqio.md +241 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
t5x-main/docs/t5x.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.meta filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
t5x-main/t5x/testdata/test_t5_tiny.checkpoint_0 filter=lfs diff=lfs merge=lfs -text
|
t5x-main/.github/workflows/build.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: build
|
| 2 |
+
|
| 3 |
+
on: [push]
|
| 4 |
+
|
| 5 |
+
jobs:
|
| 6 |
+
build:
|
| 7 |
+
runs-on: ubuntu-latest
|
| 8 |
+
steps:
|
| 9 |
+
- uses: actions/checkout@v2
|
| 10 |
+
- name: Set up Python
|
| 11 |
+
uses: actions/setup-python@v4
|
| 12 |
+
with:
|
| 13 |
+
python-version: '3.10.x'
|
| 14 |
+
cache: 'pip'
|
| 15 |
+
cache-dependency-path: setup.py
|
| 16 |
+
- name: Install dependencies
|
| 17 |
+
run: |
|
| 18 |
+
pip install -e .[test] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
| 19 |
+
- name: Test with pytest
|
| 20 |
+
run: |
|
| 21 |
+
pytest
|
| 22 |
+
# The below step just reports the success or failure of tests as a "commit status".
|
| 23 |
+
# This is needed for copybara integration.
|
| 24 |
+
- name: Report success or failure as github status
|
| 25 |
+
if: always()
|
| 26 |
+
shell: bash
|
| 27 |
+
run: |
|
| 28 |
+
status="${{ job.status }}"
|
| 29 |
+
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
|
| 30 |
+
curl -sS --request POST \
|
| 31 |
+
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
|
| 32 |
+
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
|
| 33 |
+
--header 'content-type: application/json' \
|
| 34 |
+
--data '{
|
| 35 |
+
"state": "'$lowercase_status'",
|
| 36 |
+
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
|
| 37 |
+
"description": "'$status'",
|
| 38 |
+
"context": "github-actions/build"
|
| 39 |
+
}'
|
t5x-main/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
External contributions are not accepted, sorry!
|
t5x-main/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
t5x-main/README.md
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# T5X
|
| 2 |
+
|
| 3 |
+
*Go to [T5X ReadTheDocs Documentation Page](https://t5x.readthedocs.io/).*
|
| 4 |
+
|
| 5 |
+
T5X is a modular, composable, research-friendly framework for high-performance,
|
| 6 |
+
configurable, self-service training, evaluation, and inference of sequence
|
| 7 |
+
models (starting with language) at many scales.
|
| 8 |
+
|
| 9 |
+
It is essentially a new and improved implementation of the
|
| 10 |
+
[T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer)
|
| 11 |
+
(based on [Mesh TensorFlow](https://github.com/tensorflow/mesh)) in [JAX](https://github.com/google/jax) and [Flax](https://github.com/google/flax). To learn
|
| 12 |
+
more, see the [T5X Paper](https://arxiv.org/abs/2203.17189).
|
| 13 |
+
|
| 14 |
+
Below is a quick start guide for training models with TPUs on Google Cloud. For
|
| 15 |
+
additional tutorials and background, see the [complete documentation](docs/index.md).
|
| 16 |
+
|
| 17 |
+
## Quickstart (Recommended)
|
| 18 |
+
|
| 19 |
+
T5X can be run with [XManager](https://github.com/deepmind/xmanager) on
|
| 20 |
+
[Vertex AI](https://cloud.google.com/vertex-ai). Vertex AI is a platform for
|
| 21 |
+
training that creates TPU instances and runs code on the TPUs. Vertex AI will
|
| 22 |
+
also shut down the TPUs when the jobs terminate. This is signifcantly easier
|
| 23 |
+
than managing GCE VMs and TPU VM instances.
|
| 24 |
+
|
| 25 |
+
1. Follow the pre-requisites and directions to install [XManager](https://github.com/deepmind/xmanager).
|
| 26 |
+
|
| 27 |
+
2. Request TPU quota as required. GCP projects come with 8 cores by default,
|
| 28 |
+
which is enough to run one training experiment on a single TPU host. If you want
|
| 29 |
+
to run multi-host training or run multiple trials in parallel, you will need
|
| 30 |
+
more quota. Navigate to [Quotas](https://console.cloud.google.com/quotas).
|
| 31 |
+
|
| 32 |
+
The quota you want is:
|
| 33 |
+
|
| 34 |
+
* Service: `Vertex AI API`
|
| 35 |
+
* Dimensions (location): `us-central1`
|
| 36 |
+
* If you want to run single-host experiments:
|
| 37 |
+
* `Custom model training TPU V2 cores per region`
|
| 38 |
+
* `Custom model training TPU V3 cores per region`
|
| 39 |
+
* If you want to run multi-host experiments:
|
| 40 |
+
* `Custom model training TPU V2 pod cores per region`
|
| 41 |
+
* `Custom model training TPU V3 pod cores per region`
|
| 42 |
+
|
| 43 |
+
TIP: You won't be able to run single-host experiments with multi-host quota.
|
| 44 |
+
(i.e. you can't run `tpu_v2=8` using `TPU V2 pod`)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
3. Launch the xmanager script located at `t5x/scripts/xm_launch.py`.
|
| 48 |
+
|
| 49 |
+
As a running example, we use the WMT14 En-De translation which is described in
|
| 50 |
+
more detail in the Examples section below.
|
| 51 |
+
|
| 52 |
+
```sh
|
| 53 |
+
export GOOGLE_CLOUD_BUCKET_NAME=...
|
| 54 |
+
export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data
|
| 55 |
+
export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d)
|
| 56 |
+
|
| 57 |
+
# Pre-download dataset in multi-host experiments.
|
| 58 |
+
tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR
|
| 59 |
+
|
| 60 |
+
git clone https://github.com/google-research/t5x
|
| 61 |
+
cd ./t5x/
|
| 62 |
+
|
| 63 |
+
python3 ./t5x/scripts/xm_launch.py \
|
| 64 |
+
--gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \
|
| 65 |
+
--model_dir=$MODEL_DIR \
|
| 66 |
+
--tfds_data_dir=$TFDS_DATA_DIR
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Check `gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/` for the output artifacts, which can
|
| 70 |
+
be read by TensorBoard.
|
| 71 |
+
|
| 72 |
+
## GPU Usage
|
| 73 |
+
Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository for more details and usage instructions.
|
| 74 |
+
|
| 75 |
+
T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements.
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
## Installation
|
| 79 |
+
|
| 80 |
+
Note that all the commands in this document should be run in the commandline of
|
| 81 |
+
the TPU VM instance unless otherwise stated.
|
| 82 |
+
|
| 83 |
+
1. Follow the
|
| 84 |
+
[instructions](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_the_google_cloud_sdk)
|
| 85 |
+
to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU
|
| 86 |
+
API.
|
| 87 |
+
|
| 88 |
+
**Note:** T5X also works with GPU, please follow instructions in [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md) if you'd like to use GPU version.
|
| 89 |
+
|
| 90 |
+
2. Create a
|
| 91 |
+
[Cloud TPU VM instance](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)
|
| 92 |
+
following
|
| 93 |
+
[this instruction](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#create-vm).
|
| 94 |
+
We recommend that you develop your workflow in a single v3-8 TPU (i.e.,
|
| 95 |
+
`--accelerator-type=v3-8`) and scale up to pod slices once the pipeline is
|
| 96 |
+
ready. In this README, we focus on using a single v3-8 TPU. See
|
| 97 |
+
[here](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) to
|
| 98 |
+
learn more about TPU architectures.
|
| 99 |
+
|
| 100 |
+
3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM.
|
| 101 |
+
You can install packages, run your code run, etc. in the host machine. Once
|
| 102 |
+
the TPU instance is created, ssh into it with
|
| 103 |
+
|
| 104 |
+
```sh
|
| 105 |
+
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
where `TPU_NAME` and `ZONE` are the name and the zone used in step 2.
|
| 109 |
+
|
| 110 |
+
4. Install T5X and the dependencies.
|
| 111 |
+
|
| 112 |
+
```sh
|
| 113 |
+
git clone --branch=main https://github.com/google-research/t5x
|
| 114 |
+
cd t5x
|
| 115 |
+
|
| 116 |
+
python3 -m pip install -e '.[tpu]' -f \
|
| 117 |
+
https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
5. Create Google Cloud Storage (GCS) bucket to store the dataset and model
|
| 123 |
+
checkpoints. To create a GCS bucket, see these
|
| 124 |
+
[instructions](https://cloud.google.com/storage/docs/creating-buckets).
|
| 125 |
+
|
| 126 |
+
6. (optional) If you prefer working with Jupyter/Colab style environment
|
| 127 |
+
you can setup a custom Colab runtime by following steps from
|
| 128 |
+
[t5x/notebooks](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).
|
| 129 |
+
|
| 130 |
+
## Example: English to German translation
|
| 131 |
+
|
| 132 |
+
As a running example, we use the WMT14 En-De translation. The raw dataset is
|
| 133 |
+
available in TensorFlow Datasets as
|
| 134 |
+
["wmt_t2t_translate"](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate).
|
| 135 |
+
|
| 136 |
+
T5 casts the translation task such as the following
|
| 137 |
+
|
| 138 |
+
```py
|
| 139 |
+
{'en': 'That is good.', 'de': 'Das ist gut.'}
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
to the form called "text-to-text":
|
| 143 |
+
|
| 144 |
+
```py
|
| 145 |
+
{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
This formulation allows many different classes of language tasks to be expressed
|
| 149 |
+
in a uniform manner and a single encoder-decoder architecture can handle them
|
| 150 |
+
without any task-specific parameters. For more detail, refer to the [T5 paper
|
| 151 |
+
(Raffel et al. 2019)][t5_paper].
|
| 152 |
+
|
| 153 |
+
For a scalable data pipeline and an evaluation framework, we use
|
| 154 |
+
[`SeqIO`](https://github.com/google/seqio), which was factored out of the [T5
|
| 155 |
+
library][t5_github]. A `seqio.Task` packages together the raw dataset, vocabulary,
|
| 156 |
+
preprocessing such as tokenization and evaluation metrics such as
|
| 157 |
+
[BLEU](https://aclanthology.org/P02-1040.pdf) and provides a
|
| 158 |
+
[`tf.data`](https://www.tensorflow.org/guide/data) instance.
|
| 159 |
+
|
| 160 |
+
[The T5 library][t5_github] provides a number of `seqio.Task`s that were used in the
|
| 161 |
+
[T5 paper][t5_paper]. In this example, we use [wmt_t2t_ende_v003](https://github.com/google-research/text-to-text-transfer-transformer/blob/d81c0bab2a41b4d5dfbe4971de32f7d67df65f31/t5/data/tasks.py#L212).
|
| 162 |
+
|
| 163 |
+
Before training or fine-tuning you need to download ["wmt_t2t_translate"]
|
| 164 |
+
(https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate) dataset first.
|
| 165 |
+
|
| 166 |
+
```sh
|
| 167 |
+
# Data dir to save the processed dataset in "gs://data_dir" format.
|
| 168 |
+
TFDS_DATA_DIR="..."
|
| 169 |
+
|
| 170 |
+
# Make sure that dataset package is up-to-date.
|
| 171 |
+
python3 -m pip install --upgrade tfds-nightly
|
| 172 |
+
|
| 173 |
+
# Pre-download dataset.
|
| 174 |
+
tfds build wmt_t2t_translate ${TFDS_DATA_DIR}
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Training
|
| 178 |
+
|
| 179 |
+
To run a training job, we use the `t5x/train.py` script.
|
| 180 |
+
|
| 181 |
+
```sh
|
| 182 |
+
# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
|
| 183 |
+
MODEL_DIR="..."
|
| 184 |
+
T5X_DIR="..." # directory where the T5X repo is cloned.
|
| 185 |
+
TFDS_DATA_DIR="..."
|
| 186 |
+
|
| 187 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
| 188 |
+
--gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
|
| 189 |
+
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
|
| 190 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
The configuration for this training run is defined in the Gin file
|
| 194 |
+
[base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin).
|
| 195 |
+
[Gin-config](https://github.com/google/gin-config) is a library to handle
|
| 196 |
+
configurations based on dependency injection. Among many benefits, Gin allows
|
| 197 |
+
users to pass custom components such as a custom model to the T5X library
|
| 198 |
+
without having to modify the core library. The [custom
|
| 199 |
+
components](#custom-components) section shows how this is done.
|
| 200 |
+
|
| 201 |
+
While the core library is independent of Gin, it is central to the examples we
|
| 202 |
+
provide. Therefore, we provide a short [introduction][gin-primer] to Gin in the
|
| 203 |
+
context of T5X. All the configurations are written to a file "config.gin" in
|
| 204 |
+
`MODEL_DIR`. This makes debugging as well as reproducing the experiment much
|
| 205 |
+
easier.
|
| 206 |
+
|
| 207 |
+
In addition to the `config.json`, `model-info.txt` file summarizes the model
|
| 208 |
+
parameters (shape, names of the axes, partitioning info) as well as the
|
| 209 |
+
optimizer states.
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
#### TensorBoard
|
| 214 |
+
|
| 215 |
+
To monitor the training in [TensorBoard](https://www.tensorflow.org/tensorboard), it is much easier (due to
|
| 216 |
+
authentification issues) to launch the TensorBoard on your own machine and _not_ in
|
| 217 |
+
the TPU VM. So in the commandline where you ssh'ed into the TPU VM, launch the
|
| 218 |
+
TensorBoard with the `logdir` pointing to the `MODEL_DIR`.
|
| 219 |
+
|
| 220 |
+
```sh
|
| 221 |
+
# NB: run this on your machine not TPU VM!
|
| 222 |
+
MODEL_DIR="..." # Copy from the TPU VM.
|
| 223 |
+
tensorboard --logdir=${MODEL_DIR}
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
Or you can launch the TensorBoard inside a Colab. In a Colab cell, run
|
| 227 |
+
|
| 228 |
+
```python
|
| 229 |
+
from google.colab import auth
|
| 230 |
+
auth.authenticate_user()
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
to authorize the Colab to access the GCS bucket and launch the TensorBoard.
|
| 234 |
+
|
| 235 |
+
```python
|
| 236 |
+
%load_ext tensorboard
|
| 237 |
+
model_dir = "..." # Copy from the TPU VM.
|
| 238 |
+
%tensorboard --logdir=model_dir
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
### Fine-tuning
|
| 243 |
+
|
| 244 |
+
We can leverage the benefits of self-supervised pre-training by initializing
|
| 245 |
+
from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.
|
| 246 |
+
|
| 247 |
+
```sh
|
| 248 |
+
# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
|
| 249 |
+
MODEL_DIR="..."
|
| 250 |
+
|
| 251 |
+
# Data dir to save the processed dataset in "gs://data_dir" format.
|
| 252 |
+
TFDS_DATA_DIR="..."
|
| 253 |
+
T5X_DIR="..." # directory where the T5X repo is cloned.
|
| 254 |
+
|
| 255 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
| 256 |
+
--gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin" \
|
| 257 |
+
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
|
| 258 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
**Note:** when supplying a string, dict, list, tuple value, or a bash variable
|
| 262 |
+
via a flag, you must put it in quotes. In the case of strings, it requires
|
| 263 |
+
escaped quotes (`\"<string>\"`). For example:
|
| 264 |
+
`--gin.utils.DatasetConfig.split=\"validation\"` or
|
| 265 |
+
`--gin.MODEL_DIR=\"${MODEL_DIR}\"`.
|
| 266 |
+
|
| 267 |
+
Gin makes it easy to change a number of configurations. For example, you can
|
| 268 |
+
change the `partitioning.PjitPartitioner.num_partitions` (overriding
|
| 269 |
+
the value in
|
| 270 |
+
[base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin))
|
| 271 |
+
to chanage the parallelism strategy and pass it as a commandline arg.
|
| 272 |
+
|
| 273 |
+
```sh
|
| 274 |
+
--gin.partitioning.PjitPartitioner.num_partitions=8
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
### Evaluation
|
| 278 |
+
|
| 279 |
+
To run the offline (i.e. without training) evaluation, you can use `t5x/eval.py`
|
| 280 |
+
script.
|
| 281 |
+
|
| 282 |
+
```sh
|
| 283 |
+
EVAL_OUTPUT_DIR="..." # directory to write eval output
|
| 284 |
+
T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
|
| 285 |
+
TFDS_DATA_DIR="..."
|
| 286 |
+
CHECKPOINT_PATH="..."
|
| 287 |
+
|
| 288 |
+
python3 ${T5X_DIR}/t5x/eval.py \
|
| 289 |
+
--gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin" \
|
| 290 |
+
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
| 291 |
+
--gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
|
| 292 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
### Inference
|
| 297 |
+
|
| 298 |
+
To run inference, you can use `t5x/infer.py` script. Here we use the same
|
| 299 |
+
`seqio.Task`, but for inference we do not use the targets features other than
|
| 300 |
+
logging them alongside the prediction in a JSON file.
|
| 301 |
+
|
| 302 |
+
```sh
|
| 303 |
+
INFER_OUTPUT_DIR="..." # directory to write infer output
|
| 304 |
+
T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
|
| 305 |
+
TFDS_DATA_DIR="..."
|
| 306 |
+
CHECKPOINT_PATH="..."
|
| 307 |
+
|
| 308 |
+
python3 ${T5X_DIR}/t5x/infer.py \
|
| 309 |
+
--gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin" \
|
| 310 |
+
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
| 311 |
+
--gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
|
| 312 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
### Exporting as TensorFlow Saved Model
|
| 316 |
+
|
| 317 |
+
Pretrained model can be exported as TensorFlow Saved Model, and deployed
|
| 318 |
+
to Vertex AI Prediction service using [Optimized TensorFlow Runtime]
|
| 319 |
+
(https://cloud.google.com/vertex-ai/docs/predictions/optimized-tensorflow-runtime).
|
| 320 |
+
Please note that exported model won't work on OSS based
|
| 321 |
+
[TensorFlow Model Server](https://github.com/tensorflow/serving).
|
| 322 |
+
|
| 323 |
+
```sh
|
| 324 |
+
T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
|
| 325 |
+
CHECKPOINT_PATH="..."
|
| 326 |
+
|
| 327 |
+
BATCH_SIZE=None
|
| 328 |
+
BEAM_SIZE=1
|
| 329 |
+
|
| 330 |
+
# Use 'bfloat16' if you plan to run exported model on NVIDIA A100 or newer GPUs,
|
| 331 |
+
# for other GPUs use 'float32'.
|
| 332 |
+
ACTIVATION_DTYPE=bfloat16
|
| 333 |
+
|
| 334 |
+
# Version numbers must be numeric. We generate one based on datetime.
|
| 335 |
+
VERSION=$(date +%Y%m%d%H%M%S)
|
| 336 |
+
|
| 337 |
+
NAME=t5x_base_${ACTIVATION_DTYPE} # Model name.
|
| 338 |
+
|
| 339 |
+
# Path to export model to. Note that export script is going to add _cpu suffix
|
| 340 |
+
# after model name.
|
| 341 |
+
OUTPUT=${CHECKPOINT_PATH}/saved_model.${NAME}/${VERSION}
|
| 342 |
+
|
| 343 |
+
declare -a ARGS=(
|
| 344 |
+
--gin_file=t5x/examples/t5/t5_1_1/base.gin
|
| 345 |
+
--gin_file=t5x/t5x/configs/runs/export.gin
|
| 346 |
+
--gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}"
|
| 347 |
+
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\"
|
| 348 |
+
--gin.MODEL_NAME=\"/ml/${USER}/t5x_base\"
|
| 349 |
+
--gin.MODEL_OUTPUT_DIR=\"${OUTPUT}\"
|
| 350 |
+
--gin.BEAM_SIZE=${BEAM_SIZE}
|
| 351 |
+
--gin.BATCH_SIZE=${BATCH_SIZE}
|
| 352 |
+
--gin.export_lib.save.partitioner=None
|
| 353 |
+
--gin.export_lib.save.warmup_examples="['hello world']"
|
| 354 |
+
--gin.export_lib.ExportableModule.use_batch_function=False
|
| 355 |
+
--gin.export_lib.ExportableModule.use_gpu=False
|
| 356 |
+
--gin.export_lib.ExportableModule.jit_compile=False
|
| 357 |
+
--gin.ACTIVATION_DTYPE=\"${ACTIVATION_DTYPE}\"
|
| 358 |
+
--gin.network.T5Config.dtype=\"${ACTIVATION_DTYPE}\"
|
| 359 |
+
--gin.utils.RestoreCheckpointConfig.dtype=\"${ACTIVATION_DTYPE}\"
|
| 360 |
+
--gin.DROPOUT_RATE=0.0
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
(python3 ${T5X_DIR}/t5x/export.py "${ARGS[@]}")
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
For detailed arguments definition refer to [export.gin]
|
| 367 |
+
(t5x/configs/runs/export.gin).
|
| 368 |
+
|
| 369 |
+
You can run XL and smaller models on NVIDIA A100 40GB, and XXL models on
|
| 370 |
+
NVIDIA A100 80GB.
|
| 371 |
+
|
| 372 |
+
## Custom components
|
| 373 |
+
|
| 374 |
+
[The translation example](#example-english-to-german-translation) uses the
|
| 375 |
+
encoder-decoder model that T5X provides as well as the dataset from the T5
|
| 376 |
+
library. This section shows how you can use your own dataset and a model and
|
| 377 |
+
pass via Gin.
|
| 378 |
+
|
| 379 |
+
### Example: custom dataset in a user directory
|
| 380 |
+
|
| 381 |
+
For this example, we have the following directory structure with
|
| 382 |
+
`${HOME}/dir1/user_dir` representing a user directory with custom components.
|
| 383 |
+
|
| 384 |
+
```
|
| 385 |
+
${HOME}
|
| 386 |
+
└── dir1
|
| 387 |
+
└── user_dir
|
| 388 |
+
├── t5_1_1_base_de_en.gin
|
| 389 |
+
└── tasks.py
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
As an example, let's define a new dataset. Here we use the same Translation
|
| 393 |
+
dataset but we define the translation task in the opposite direction, i.e.,
|
| 394 |
+
German to English intead of English to German. We define this task in `tasks.py`
|
| 395 |
+
|
| 396 |
+
```py
|
| 397 |
+
# ${HOME}/dir1/user_dir/tasks.py
|
| 398 |
+
|
| 399 |
+
import functools
|
| 400 |
+
import seqio
|
| 401 |
+
import tensorflow_datasets as tfds
|
| 402 |
+
from t5.evaluation import metrics
|
| 403 |
+
from t5.data import preprocessors
|
| 404 |
+
|
| 405 |
+
vocabulary = seqio.SentencePieceVocabulary(
|
| 406 |
+
'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
|
| 407 |
+
output_features = {
|
| 408 |
+
'inputs': seqio.Feature(vocabulary=vocabulary),
|
| 409 |
+
'targets': seqio.Feature(vocabulary=vocabulary)
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
seqio.TaskRegistry.add(
|
| 413 |
+
'wmt_t2t_de_en_v003',
|
| 414 |
+
source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
|
| 415 |
+
preprocessors=[
|
| 416 |
+
functools.partial(
|
| 417 |
+
preprocessors.translate,
|
| 418 |
+
source_language='de', target_language='en'),
|
| 419 |
+
seqio.preprocessors.tokenize,
|
| 420 |
+
seqio.CacheDatasetPlaceholder(),
|
| 421 |
+
seqio.preprocessors.append_eos_after_trim,
|
| 422 |
+
],
|
| 423 |
+
metric_fns=[metrics.bleu],
|
| 424 |
+
output_features=output_features)
|
| 425 |
+
```
|
| 426 |
+
|
| 427 |
+
In the Gin file, most of the settings are equivalent to those used in the
|
| 428 |
+
[En->De example](#example-english-to-german-translation). So we include the Gin
|
| 429 |
+
file from that example. To use "wmt_t2t_de_en_v003" task we just defined, we
|
| 430 |
+
need to import the task module "tasks.py". Note that we use a relative path
|
| 431 |
+
defined with respect to the user directory. This will be specified as a
|
| 432 |
+
flag.
|
| 433 |
+
|
| 434 |
+
```py
|
| 435 |
+
# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
|
| 436 |
+
from __gin__ import dynamic_registration
|
| 437 |
+
import tasks # This imports the task defined in dir1/user_dir/tasks.py.
|
| 438 |
+
|
| 439 |
+
include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin"
|
| 440 |
+
MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"
|
| 441 |
+
```
|
| 442 |
+
|
| 443 |
+
Finally, we launch training passing the user directory as a flag
|
| 444 |
+
`gin_search_paths` such that the Gin file and python modules can be specified
|
| 445 |
+
with relative paths.
|
| 446 |
+
|
| 447 |
+
```sh
|
| 448 |
+
PROJECT_DIR=${HOME}"/dir1/user_dir"
|
| 449 |
+
T5X_DIR="..." # directory where the t5x is cloned.
|
| 450 |
+
TFDS_DATA_DIR="..."
|
| 451 |
+
MODEL_DIR="..."
|
| 452 |
+
export PYTHONPATH=${PROJECT_DIR}
|
| 453 |
+
|
| 454 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
| 455 |
+
--gin_search_paths=${PROJECT_DIR} \
|
| 456 |
+
--gin_file="t5_1_1_base_de_en.gin" \
|
| 457 |
+
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
|
| 458 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 459 |
+
```
|
| 460 |
+
|
| 461 |
+
## Checkpoints
|
| 462 |
+
|
| 463 |
+
### Native Checkpoints
|
| 464 |
+
|
| 465 |
+
We have released the checkpoints of many of the original T5 models and their
|
| 466 |
+
variants a native T5X format for maximal efficiency.
|
| 467 |
+
See the [complete list](https://github.com/google-research/t5x/blob/main/docs/models.md) including the
|
| 468 |
+
matching Gin configuration files.
|
| 469 |
+
|
| 470 |
+
These are converted from the public [Mesh TensorFlow
|
| 471 |
+
checkpoints](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511)
|
| 472 |
+
.
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
### Compatibility with the Mesh TensorFlow checkpoints
|
| 476 |
+
The Mesh TensorFlow checkpoints trained using the [T5 library][t5_github] can be
|
| 477 |
+
directly loaded into T5X. For example, we can rerun the fine-tuning example
|
| 478 |
+
initializing from the MTF checkpoint by changing the `INIT_CHECKPOINT` Gin
|
| 479 |
+
macro.
|
| 480 |
+
|
| 481 |
+
```sh
|
| 482 |
+
# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
|
| 483 |
+
MODEL_DIR="..."
|
| 484 |
+
|
| 485 |
+
# Data dir to save the processed dataset in "gs://data_dir" format.
|
| 486 |
+
TFDS_DATA_DIR="..."
|
| 487 |
+
T5X_DIR="..." # directory where the T5X repo is cloned.
|
| 488 |
+
|
| 489 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
| 490 |
+
--gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin" \
|
| 491 |
+
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
|
| 492 |
+
--gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
|
| 493 |
+
--gin.INIT_CHECKPOINT=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \
|
| 494 |
+
--tfds_data_dir=${TFDS_DATA_DIR}
|
| 495 |
+
```
|
| 496 |
+
|
| 497 |
+
Note that restoring directly from the Mesh TensorFlow checkpoints can be
|
| 498 |
+
inefficient if heavy model parallelism is used for large models. This is
|
| 499 |
+
because each host loads the entire copy of the model first and then keep only
|
| 500 |
+
the relevant slices dictated by the model parallelism specification. If you have
|
| 501 |
+
Mesh TensorFlow checkpoints that you run often, we recommend converting the
|
| 502 |
+
checkpoints to T5X native format using the
|
| 503 |
+
[convert_tf_checkpoint script](t5x/scripts/convert_tf_checkpoint.py).
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
## Citing T5X
|
| 507 |
+
Please use the following bibtex entry to cite T5X.
|
| 508 |
+
|
| 509 |
+
```
|
| 510 |
+
@article{roberts2022t5x,
|
| 511 |
+
url = {https://arxiv.org/abs/2203.17189},
|
| 512 |
+
author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra, Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester, Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy, Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel, Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov, Alexander and Newlan, Joshua and Gesmundo, Andrea},
|
| 513 |
+
title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
|
| 514 |
+
journal={arXiv preprint arXiv:2203.17189},
|
| 515 |
+
year = {2022},
|
| 516 |
+
}
|
| 517 |
+
```
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
## Note
|
| 521 |
+
This is not an officially supported Google product
|
| 522 |
+
|
| 523 |
+
[t5_paper]: https://arxiv.org/abs/1910.10683
|
| 524 |
+
[t5_github]: https://github.com/google-research/text-to-text-transfer-transformer
|
| 525 |
+
[gin-primer]: docs/usage/gin.md
|
t5x-main/docs/_static/t5x_theme.css
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@import url("theme.css");
|
| 2 |
+
|
| 3 |
+
.wy-nav-content {
|
| 4 |
+
max-width: 1290px;
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
.rst-content table.docutils {
|
| 8 |
+
width: 100%;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
.rst-content table.docutils td {
|
| 12 |
+
vertical-align: top;
|
| 13 |
+
padding: 0;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
.rst-content table.docutils td p {
|
| 17 |
+
padding: 8px;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.rst-content div[class^=highlight] {
|
| 21 |
+
border: 0;
|
| 22 |
+
margin: 0;
|
| 23 |
+
}
|
t5x-main/docs/_templates/autosummary/t5x_module.rst
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{{ fullname | escape | underline}}
|
| 2 |
+
|
| 3 |
+
.. currentmodule:: {{ module }}
|
| 4 |
+
|
| 5 |
+
.. autoclass:: {{ objname }}
|
| 6 |
+
:exclude-members:
|
| 7 |
+
|
| 8 |
+
{% block methods %}
|
| 9 |
+
|
| 10 |
+
.. automethod:: __call__
|
| 11 |
+
|
| 12 |
+
{% if methods %}
|
| 13 |
+
.. rubric:: Methods
|
| 14 |
+
|
| 15 |
+
.. autosummary::
|
| 16 |
+
|
| 17 |
+
{% for item in methods %}
|
| 18 |
+
{%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %}
|
| 19 |
+
~{{ name }}.{{ item }}
|
| 20 |
+
{%- endif %}
|
| 21 |
+
{%- endfor %}
|
| 22 |
+
{% endif %}
|
| 23 |
+
{% endblock %}
|
t5x-main/docs/api_reference/index.rst
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
API Reference
|
| 2 |
+
=============
|
| 3 |
+
|
| 4 |
+
Binaries
|
| 5 |
+
--------
|
| 6 |
+
|
| 7 |
+
.. toctree::
|
| 8 |
+
:maxdepth: 3
|
| 9 |
+
|
| 10 |
+
t5x.train
|
| 11 |
+
t5x.infer
|
| 12 |
+
t5x.eval
|
| 13 |
+
t5x.main
|
| 14 |
+
|
| 15 |
+
Training
|
| 16 |
+
---------
|
| 17 |
+
|
| 18 |
+
.. toctree::
|
| 19 |
+
:maxdepth: 3
|
| 20 |
+
|
| 21 |
+
t5x.trainer
|
| 22 |
+
t5x.optimizers
|
| 23 |
+
t5x.interactive_model
|
| 24 |
+
t5x.train_state
|
| 25 |
+
t5x.state_utils
|
| 26 |
+
t5x.losses
|
| 27 |
+
t5x.metrics
|
| 28 |
+
t5x.utils
|
| 29 |
+
t5x.adafactor
|
| 30 |
+
|
| 31 |
+
Inference
|
| 32 |
+
---------
|
| 33 |
+
|
| 34 |
+
.. toctree::
|
| 35 |
+
:maxdepth: 3
|
| 36 |
+
|
| 37 |
+
t5x.decoding
|
| 38 |
+
|
| 39 |
+
Models
|
| 40 |
+
------
|
| 41 |
+
|
| 42 |
+
.. toctree::
|
| 43 |
+
:maxdepth: 3
|
| 44 |
+
|
| 45 |
+
t5x.models
|
| 46 |
+
|
| 47 |
+
Checkpointing
|
| 48 |
+
-------------
|
| 49 |
+
|
| 50 |
+
.. toctree::
|
| 51 |
+
:maxdepth: 3
|
| 52 |
+
|
| 53 |
+
t5x.checkpoints
|
| 54 |
+
t5x.checkpoint_utils
|
| 55 |
+
t5x.checkpoint_importer
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
Paritioning
|
| 59 |
+
-----------
|
| 60 |
+
|
| 61 |
+
.. toctree::
|
| 62 |
+
:maxdepth: 3
|
| 63 |
+
|
| 64 |
+
t5x.partitioning
|
| 65 |
+
|
| 66 |
+
Config
|
| 67 |
+
------
|
| 68 |
+
|
| 69 |
+
.. toctree::
|
| 70 |
+
:maxdepth: 3
|
| 71 |
+
|
| 72 |
+
t5x.config_utils
|
| 73 |
+
t5x.gin_utils
|
| 74 |
+
|
| 75 |
+
Utils
|
| 76 |
+
-----
|
| 77 |
+
|
| 78 |
+
.. toctree::
|
| 79 |
+
:maxdepth: 3
|
| 80 |
+
|
| 81 |
+
t5x.test_utils
|
| 82 |
+
t5x.binary_search
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
t5x-main/docs/api_reference/t5x.adafactor.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.adafactor package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.adafactor
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.adafactor
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.binary_search.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.binary_search package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.binary_search
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.binary_search
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.checkpoint_importer.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.checkpoint_importer package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.checkpoint_importer
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.checkpoint_importer
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.checkpoint_utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.checkpoint_utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.checkpoint_utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.checkpoint_utils
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.checkpoints.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.checkpoints package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.checkpoints
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.checkpoints
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.config_utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.config_utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.config_utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.config_utils
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.decoding.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.decoding package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.decoding
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.decoding
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.eval.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.eval binary
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.eval
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.eval
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.gin_utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.gin_utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.gin_utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.gin_utils
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.infer.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.infer binary
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.infer
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.infer
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.interactive_model.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.interactive_model package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.interactive_model
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.interactive_model
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.losses.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.losses package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.losses
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.losses
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.main.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.main binary
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.main
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.main
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.metrics.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.metrics package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.metrics
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.metrics
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.models.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.models package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.models
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.models
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.optimizers.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.optimizers package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.optimizers
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.optimizers
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.partitioning.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.partitioning package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.partitioning
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.partitioning
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.state_utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.state_utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.state_utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.state_utils
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.test_utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.test_utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.test_utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.test_utils
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.train.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.train binary
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.train
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.train
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.train_state.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.train_state package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.train_state
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.train_state
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.trainer.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.trainer package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.trainer
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.trainer
|
| 7 |
+
:members:
|
t5x-main/docs/api_reference/t5x.utils.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t5x.utils package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. currentmodule:: t5x.utils
|
| 5 |
+
|
| 6 |
+
.. automodule:: t5x.utils
|
| 7 |
+
:members:
|
t5x-main/docs/conf.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The T5X Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Configuration file for the Sphinx documentation builder.
|
| 16 |
+
|
| 17 |
+
This file only contains a selection of the most common options. For a full
|
| 18 |
+
list see the documentation:
|
| 19 |
+
https://www.sphinx-doc.org/en/master/usage/configuration.html
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
# pylint:disable=all
|
| 23 |
+
# -- Path setup --------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
| 26 |
+
# add these directories to sys.path here. If the directory is relative to the
|
| 27 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
| 28 |
+
#
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
|
| 32 |
+
sys.path.insert(0, os.path.abspath('..'))
|
| 33 |
+
|
| 34 |
+
# patch sphinx
|
| 35 |
+
import docs.conf_sphinx_patch
|
| 36 |
+
|
| 37 |
+
# -- Project information -----------------------------------------------------
|
| 38 |
+
|
| 39 |
+
project = 'T5X'
|
| 40 |
+
copyright = '2023, The T5X authors' # pylint: disable=redefined-builtin
|
| 41 |
+
author = 'The T5X authors'
|
| 42 |
+
|
| 43 |
+
# -- General configuration ---------------------------------------------------
|
| 44 |
+
|
| 45 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
| 46 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
| 47 |
+
# ones.
|
| 48 |
+
extensions = [
|
| 49 |
+
'sphinx.ext.autodoc',
|
| 50 |
+
'sphinx.ext.autosummary',
|
| 51 |
+
'sphinx.ext.autosectionlabel',
|
| 52 |
+
'sphinx.ext.doctest',
|
| 53 |
+
'sphinx.ext.intersphinx',
|
| 54 |
+
'sphinx.ext.mathjax',
|
| 55 |
+
'sphinx.ext.napoleon',
|
| 56 |
+
'sphinx.ext.viewcode',
|
| 57 |
+
'myst_nb',
|
| 58 |
+
'sphinx_design',
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
# The suffix(es) of source filenames.
|
| 62 |
+
# You can specify multiple suffix as a list of string:
|
| 63 |
+
#
|
| 64 |
+
source_suffix = ['.rst', '.ipynb', '.md']
|
| 65 |
+
|
| 66 |
+
autosummary_generate = True
|
| 67 |
+
|
| 68 |
+
master_doc = 'index'
|
| 69 |
+
|
| 70 |
+
autodoc_typehints = 'none'
|
| 71 |
+
|
| 72 |
+
# Add any paths that contain templates here, relative to this directory.
|
| 73 |
+
templates_path = ['_templates']
|
| 74 |
+
|
| 75 |
+
# List of patterns, relative to source directory, that match files and
|
| 76 |
+
# directories to ignore when looking for source files.
|
| 77 |
+
# This pattern also affects html_static_path and html_extra_path.
|
| 78 |
+
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
| 79 |
+
|
| 80 |
+
# -- Options for HTML output -------------------------------------------------
|
| 81 |
+
|
| 82 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
| 83 |
+
# a list of builtin themes.
|
| 84 |
+
#
|
| 85 |
+
# html_theme = 'pydata_sphinx_theme'
|
| 86 |
+
html_theme = 'sphinx_book_theme'
|
| 87 |
+
html_css_files = ['css/t5x_theme.css']
|
| 88 |
+
|
| 89 |
+
# The name of an image file (relative to this directory) to place at the top
|
| 90 |
+
# of the sidebar.
|
| 91 |
+
html_logo = './t5x.png'
|
| 92 |
+
html_favicon = './t5x.png'
|
| 93 |
+
|
| 94 |
+
# title of the website
|
| 95 |
+
html_title = ''
|
| 96 |
+
|
| 97 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
| 98 |
+
# relative to this directory. They are copied after the builtin static files,
|
| 99 |
+
# so a file named 'default.css' will overwrite the builtin 'default.css'.
|
| 100 |
+
html_static_path = ['_static']
|
| 101 |
+
|
| 102 |
+
html_theme_options = {
|
| 103 |
+
'repository_url': 'https://github.com/google-research/t5x',
|
| 104 |
+
'use_repository_button': True, # add a 'link to repository' button
|
| 105 |
+
'use_issues_button': False, # add an 'Open an Issue' button
|
| 106 |
+
'path_to_docs': (
|
| 107 |
+
'docs'
|
| 108 |
+
), # used to compute the path to launch notebooks in colab
|
| 109 |
+
'launch_buttons': {
|
| 110 |
+
'colab_url': 'https://colab.research.google.com/',
|
| 111 |
+
},
|
| 112 |
+
'prev_next_buttons_location': None,
|
| 113 |
+
'show_navbar_depth': 1,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
# -- Options for myst ----------------------------------------------
|
| 117 |
+
# uncomment line below to avoid running notebooks during development
|
| 118 |
+
# nb_execution_mode = 'off'
|
| 119 |
+
# Notebook cell execution timeout; defaults to 30.
|
| 120 |
+
nb_execution_timeout = 100
|
| 121 |
+
# List of patterns, relative to source directory, that match notebook
|
| 122 |
+
# files that will not be executed.
|
| 123 |
+
myst_enable_extensions = ['dollarmath']
|
| 124 |
+
# raise exceptions on execution so CI can catch errors
|
| 125 |
+
nb_execution_allow_errors = False
|
| 126 |
+
nb_execution_raise_on_error = True
|
| 127 |
+
|
| 128 |
+
# -- Extension configuration -------------------------------------------------
|
| 129 |
+
|
| 130 |
+
# Tell sphinx-autodoc-typehints to generate stub parameter annotations including
|
| 131 |
+
# types, even if the parameters aren't explicitly documented.
|
| 132 |
+
always_document_param_types = True
|
t5x-main/docs/conf_sphinx_patch.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The T5X Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Patch Sphinx to improve documentation aesthetics."""
|
| 16 |
+
|
| 17 |
+
# TODO(cgarciae): Send a PR to sphinx to upstream this fix.
|
| 18 |
+
# Issue: https://github.com/google/flax/issues/2196
|
| 19 |
+
# This patch is needed to make autosummary provide the "annotations"
|
| 20 |
+
# variable so we can exclude function attributes from the methods list
|
| 21 |
+
# in flax_module.rst. The patch as such only adds this single line:
|
| 22 |
+
#
|
| 23 |
+
# ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())'
|
| 24 |
+
#
|
| 25 |
+
# We should consider sending a PR to sphinx so we can get rid of this.
|
| 26 |
+
# Original source:
|
| 27 |
+
# https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351
|
| 28 |
+
from typing import Any, Dict, List, Set, Tuple
|
| 29 |
+
import sphinx.ext.autodoc
|
| 30 |
+
import sphinx.ext.autosummary.generate as ag
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# pylint:disable=all
|
| 34 |
+
def generate_autosummary_content(
|
| 35 |
+
name: str,
|
| 36 |
+
obj: Any,
|
| 37 |
+
parent: Any,
|
| 38 |
+
template: ag.AutosummaryRenderer,
|
| 39 |
+
template_name: str,
|
| 40 |
+
imported_members: bool,
|
| 41 |
+
app: Any,
|
| 42 |
+
recursive: bool,
|
| 43 |
+
context: Dict,
|
| 44 |
+
modname: str = None,
|
| 45 |
+
qualname: str = None,
|
| 46 |
+
) -> str:
|
| 47 |
+
doc = ag.get_documenter(app, obj, parent)
|
| 48 |
+
|
| 49 |
+
def skip_member(obj: Any, name: str, objtype: str) -> bool:
|
| 50 |
+
try:
|
| 51 |
+
return app.emit_firstresult(
|
| 52 |
+
'autodoc-skip-member', objtype, name, obj, False, {}
|
| 53 |
+
)
|
| 54 |
+
except Exception as exc:
|
| 55 |
+
ag.logger.warning(
|
| 56 |
+
__(
|
| 57 |
+
'autosummary: failed to determine %r to be documented, '
|
| 58 |
+
'the following exception was raised:\n%s'
|
| 59 |
+
),
|
| 60 |
+
name,
|
| 61 |
+
exc,
|
| 62 |
+
type='autosummary',
|
| 63 |
+
)
|
| 64 |
+
return False
|
| 65 |
+
|
| 66 |
+
def get_class_members(obj: Any) -> Dict[str, Any]:
|
| 67 |
+
members = sphinx.ext.autodoc.get_class_members(
|
| 68 |
+
obj, [qualname], ag.safe_getattr
|
| 69 |
+
)
|
| 70 |
+
return {name: member.object for name, member in members.items()}
|
| 71 |
+
|
| 72 |
+
def get_module_members(obj: Any) -> Dict[str, Any]:
|
| 73 |
+
members = {}
|
| 74 |
+
for name in ag.members_of(obj, app.config):
|
| 75 |
+
try:
|
| 76 |
+
members[name] = ag.safe_getattr(obj, name)
|
| 77 |
+
except AttributeError:
|
| 78 |
+
continue
|
| 79 |
+
return members
|
| 80 |
+
|
| 81 |
+
def get_all_members(obj: Any) -> Dict[str, Any]:
|
| 82 |
+
if doc.objtype == 'module':
|
| 83 |
+
return get_module_members(obj)
|
| 84 |
+
elif doc.objtype == 'class':
|
| 85 |
+
return get_class_members(obj)
|
| 86 |
+
return {}
|
| 87 |
+
|
| 88 |
+
def get_members(
|
| 89 |
+
obj: Any,
|
| 90 |
+
types: Set[str],
|
| 91 |
+
include_public: List[str] = [],
|
| 92 |
+
imported: bool = True,
|
| 93 |
+
) -> Tuple[List[str], List[str]]:
|
| 94 |
+
items: List[str] = []
|
| 95 |
+
public: List[str] = []
|
| 96 |
+
|
| 97 |
+
all_members = get_all_members(obj)
|
| 98 |
+
for name, value in all_members.items():
|
| 99 |
+
documenter = ag.get_documenter(app, value, obj)
|
| 100 |
+
if documenter.objtype in types:
|
| 101 |
+
# skip imported members if expected
|
| 102 |
+
if imported or getattr(value, '__module__', None) == obj.__name__:
|
| 103 |
+
skipped = skip_member(value, name, documenter.objtype)
|
| 104 |
+
if skipped is True:
|
| 105 |
+
pass
|
| 106 |
+
elif skipped is False:
|
| 107 |
+
# show the member forcedly
|
| 108 |
+
items.append(name)
|
| 109 |
+
public.append(name)
|
| 110 |
+
else:
|
| 111 |
+
items.append(name)
|
| 112 |
+
if name in include_public or not name.startswith('_'):
|
| 113 |
+
# considers member as public
|
| 114 |
+
public.append(name)
|
| 115 |
+
return public, items
|
| 116 |
+
|
| 117 |
+
def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]:
|
| 118 |
+
"""Find module attributes with docstrings."""
|
| 119 |
+
attrs, public = [], []
|
| 120 |
+
try:
|
| 121 |
+
analyzer = ag.ModuleAnalyzer.for_module(name)
|
| 122 |
+
attr_docs = analyzer.find_attr_docs()
|
| 123 |
+
for namespace, attr_name in attr_docs:
|
| 124 |
+
if namespace == '' and attr_name in members:
|
| 125 |
+
attrs.append(attr_name)
|
| 126 |
+
if not attr_name.startswith('_'):
|
| 127 |
+
public.append(attr_name)
|
| 128 |
+
except ag.PycodeError:
|
| 129 |
+
pass # give up if ModuleAnalyzer fails to parse code
|
| 130 |
+
return public, attrs
|
| 131 |
+
|
| 132 |
+
def get_modules(obj: Any) -> Tuple[List[str], List[str]]:
|
| 133 |
+
items: List[str] = []
|
| 134 |
+
for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__):
|
| 135 |
+
fullname = name + '.' + modname
|
| 136 |
+
try:
|
| 137 |
+
module = ag.import_module(fullname)
|
| 138 |
+
if module and hasattr(module, '__sphinx_mock__'):
|
| 139 |
+
continue
|
| 140 |
+
except ImportError:
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
items.append(fullname)
|
| 144 |
+
public = [x for x in items if not x.split('.')[-1].startswith('_')]
|
| 145 |
+
return public, items
|
| 146 |
+
|
| 147 |
+
ns: Dict[str, Any] = {}
|
| 148 |
+
ns.update(context)
|
| 149 |
+
|
| 150 |
+
if doc.objtype == 'module':
|
| 151 |
+
scanner = ag.ModuleScanner(app, obj)
|
| 152 |
+
ns['members'] = scanner.scan(imported_members)
|
| 153 |
+
ns['functions'], ns['all_functions'] = get_members(
|
| 154 |
+
obj, {'function'}, imported=imported_members
|
| 155 |
+
)
|
| 156 |
+
ns['classes'], ns['all_classes'] = get_members(
|
| 157 |
+
obj, {'class'}, imported=imported_members
|
| 158 |
+
)
|
| 159 |
+
ns['exceptions'], ns['all_exceptions'] = get_members(
|
| 160 |
+
obj, {'exception'}, imported=imported_members
|
| 161 |
+
)
|
| 162 |
+
ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members'])
|
| 163 |
+
ispackage = hasattr(obj, '__path__')
|
| 164 |
+
if ispackage and recursive:
|
| 165 |
+
ns['modules'], ns['all_modules'] = get_modules(obj)
|
| 166 |
+
elif doc.objtype == 'class':
|
| 167 |
+
ns['members'] = dir(obj)
|
| 168 |
+
ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys())
|
| 169 |
+
ns['methods'], ns['all_methods'] = get_members(
|
| 170 |
+
obj, {'method'}, ['__init__']
|
| 171 |
+
)
|
| 172 |
+
ns['attributes'], ns['all_attributes'] = get_members(
|
| 173 |
+
obj, {'attribute', 'property'}
|
| 174 |
+
)
|
| 175 |
+
ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())
|
| 176 |
+
|
| 177 |
+
if modname is None or qualname is None:
|
| 178 |
+
modname, qualname = ag.split_full_qualified_name(name)
|
| 179 |
+
|
| 180 |
+
if doc.objtype in ('method', 'attribute', 'property'):
|
| 181 |
+
ns['class'] = qualname.rsplit('.', 1)[0]
|
| 182 |
+
|
| 183 |
+
if doc.objtype in ('class',):
|
| 184 |
+
shortname = qualname
|
| 185 |
+
else:
|
| 186 |
+
shortname = qualname.rsplit('.', 1)[-1]
|
| 187 |
+
|
| 188 |
+
ns['fullname'] = name
|
| 189 |
+
ns['module'] = modname
|
| 190 |
+
ns['objname'] = qualname
|
| 191 |
+
ns['name'] = shortname
|
| 192 |
+
|
| 193 |
+
ns['objtype'] = doc.objtype
|
| 194 |
+
ns['underline'] = len(name) * '='
|
| 195 |
+
|
| 196 |
+
if template_name:
|
| 197 |
+
return template.render(template_name, ns)
|
| 198 |
+
else:
|
| 199 |
+
return template.render(doc.objtype, ns)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
ag.generate_autosummary_content = generate_autosummary_content
|
t5x-main/docs/contributions.md
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributions
|
| 2 |
+
|
| 3 |
+
T5X was developed as part of the T5 Infrastructure effort at Google Research.
|
| 4 |
+
|
| 5 |
+
Adam Roberts founded and leads the project, designed and wrote much of `seqio`
|
| 6 |
+
and `t5x`, and co-authored the
|
| 7 |
+
[T5X and SeqIO paper](https://arxiv.org/abs/2203.17189). Hyung Won Chung
|
| 8 |
+
designed and wrote much of `t5x`, led its open sourcing, and co-authored the
|
| 9 |
+
paper. Anselm Levskaya built the initial prototype for `t5x` and wrote much of
|
| 10 |
+
the code. Gaurav Mishra leads `seqio`, implemented deterministic pipelines, and
|
| 11 |
+
co-authored the paper. James Bradbury implemented partitioning in `t5x` and
|
| 12 |
+
co-wrote the paper.
|
| 13 |
+
|
| 14 |
+
Daniel Andor, Sharan Narang, Brian Lester, Colin Gaffney, Afroz Mohiuddin,
|
| 15 |
+
Curtis Hawthorne, Aitor Lewkowycz, Alex Salcianu, Marc van Zee, Jacob Austin,
|
| 16 |
+
Sebastian Good-man, Livio Baldini Soares, Haitang Hu, Sasha Tsvyashchenko,
|
| 17 |
+
Aakanksha Chowdhery, Jasmijn Bastings, Jannis Bulian, Xavier Garcia, Jianmo Ni,
|
| 18 |
+
Andrew Chen, Kathleen Kenealy, Kehang Han, Jonathan H. Clark, Stephan Lee, Dan
|
| 19 |
+
Garrette, and James Lee-Thorp made substantial code contributions.
|
| 20 |
+
|
| 21 |
+
Colin Raffel and Noam Shazeer helped design `seqio`. Marvin Ritter advised on
|
| 22 |
+
deterministic pipelines and the use of CLU Metrics. Maarten Bosma helped design
|
| 23 |
+
deterministic pipelines. Jeremy Maitin-Shepard advised on the use of
|
| 24 |
+
TensorStore. Alexandre Passos and Ryan Sepassi advised on overall technical
|
| 25 |
+
design.
|
| 26 |
+
|
| 27 |
+
Noah Fiedel is a member of the leadership team, contributed to the high level
|
| 28 |
+
design and roadmap, and co-wrote the paper. Mark Omernick, Brennan Saeta, Ryan
|
| 29 |
+
Sepassi, Alexander Spiridonov (Product Manager), and Josh Newlan (Technical
|
| 30 |
+
Program Manager) are members of the leadership team and co-wrote the paper.
|
| 31 |
+
Andrea Gesmundo is a member of the leadership team and contributed to the
|
| 32 |
+
internal infrastructure component.
|
| 33 |
+
|
| 34 |
+
Thanks to the many other contributors to the project: Ian Simon, Reiner Pope,
|
| 35 |
+
Vincent Zhao, Pierre Ruyssen, Linting Xue, Junwhan Ahn, Barret Zoph, David
|
| 36 |
+
Dohan, Masumi Parekh, Chang Lan, Frederick Liu, Julien Amelot, Luheng He, Fede
|
| 37 |
+
Lebron, RebeccaChen, Anosh Raj, Mandy Guo, Ethan Dyer, Mihai Tiuca, Hongkun Yu,
|
| 38 |
+
Kevin Brooks, David Soergel, Kelvin Guu, Joshua Ainslie, Luyao Xu, Ji Ma, Josh
|
| 39 |
+
Gardner, Daphne Ippolito, Peter Hawkins, Bo Pang, Marc Rasi, Wei Li, Wenhu Chen,
|
| 40 |
+
Iulia Turc, John Wieting, Alex Passos, Zonglin Li, Katie Everett, Olivier
|
| 41 |
+
Bachem, Francesco Piccinno, Jakub Adamek, Jonathan Heek, Parker Schuh, Hexiang
|
| 42 |
+
Hu, Du Phan, Max Moroz, David Miller, Ryan Doherty, David Elworthy, Alfonso
|
| 43 |
+
Casta ̃no, Julian Eisenschlos, Vlad-Doru Ion, Lucas Dixon, Ron Shapiro, Dinghua
|
| 44 |
+
Li, Aaron Parisi, Xi Chen, Nan Ding, Chung-ching Chang, Timothy Dozat, Natalia
|
| 45 |
+
Ponomareva, Delesley Hutchins, Ankush Garg, Yu-Han Liu, Mehrdad Khatir, Costanza
|
| 46 |
+
Conforti, Philipp Keck, Rapha ̈el Marinier, Marie Pellat, Raghuram Vadapalli,
|
| 47 |
+
Joshua Maynez, Yi Tay, Xihui Wu, David Belanger, Luke Metz, Dan Zheng, Deepti
|
| 48 |
+
Bhatia, Hariharan Shanmugavadivel, Rewon Child, Rigel Swavely, Mihir Sanjay
|
| 49 |
+
Kale, Arash Afkanpour, Roberto Rama, Juro Gottweis, Jonathan Herzig, Yilei Yang,
|
| 50 |
+
Elias Mizan, Pedram Pejman, Jiayu Ye, Smit Sanghavi, Rahul Joshi, Ziqiang Feng,
|
| 51 |
+
Charles Sutton, Weikang Zhou, Liam Fedus, Shanqing Cai, Ginger Perng, Yash
|
| 52 |
+
Katariya, Urvashi Khandelwal, Sebastian Gehrmann, Edward Loper, Tianze Shi, Luke
|
| 53 |
+
Vilnis, Amelia Archer, Tom Weingarten, David Zats, Murtaza Dhuliawala, Xin Xie,
|
| 54 |
+
Sahil Dua, Andr ́e SusanoPinto, Piotr Padlewski, Sascha Rothe, Erik Aas, Felix
|
| 55 |
+
Stahlberg, Ken Durden, Christina Sorokin, Jaehoon Lee, Roy Frostig, Jacob
|
| 56 |
+
Devlin, Jorge Gonzalez Mendez, Deepak Ramachandran, Santiago Ontanon, Karthik
|
| 57 |
+
Raman, Yi Sun, Ali Elqursh, Reuben La Haye,Adam Fahrenkopf, Alex Polozov, Vinay
|
| 58 |
+
Ramasesh, Ian Tenney.
|
| 59 |
+
|
| 60 |
+
Thanks to NVIDIA for GPU contributions: Sahil Jain, Terry Kong, Yu-Hang Tang,
|
| 61 |
+
Ming Huang, Frederic Bastien, Sharath Turuvekere Sreenivas, Xiaowei Ren, Ryan Jeng,
|
| 62 |
+
Reese Wang
|
| 63 |
+
|
| 64 |
+
Thanks to Douglas Eck and Zoubin Ghahramani for sponsoring the project.
|
t5x-main/docs/index.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# T5X
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
Note: T5X is community-supported since ~2023. For critical use cases, consider
|
| 5 |
+
using libraries like TuneLab (go/tunelab) and Gemax Prod (go/gemax-prod). See
|
| 6 |
+
https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-to-gemax-prod for useful tips on transitioning.
|
| 7 |
+
|
| 8 |
+
## Overview
|
| 9 |
+
|
| 10 |
+
T5X is a modular, composable, research-friendly framework for high-performance,
|
| 11 |
+
configurable, self-service training, evaluation, and inference of sequence
|
| 12 |
+
models (starting with language) at many scales.
|
| 13 |
+
|
| 14 |
+
It is essentially a new and improved implementation of the
|
| 15 |
+
[T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md) (based on Mesh TensorFlow) in JAX and Flax. To learn
|
| 16 |
+
more, see the [T5X Paper](https://arxiv.org/abs/2203.17189).
|
| 17 |
+
|
| 18 |
+
## Getting Started
|
| 19 |
+
|
| 20 |
+
Here are some quick tutorials to help you get started with common use-cases on
|
| 21 |
+
T5X:
|
| 22 |
+
|
| 23 |
+
#### [Introductory Colabs](tutorials.md)
|
| 24 |
+
|
| 25 |
+
If you are new to T5X, we recommend starting with our introductory Colab series,
|
| 26 |
+
which introduces core concepts of both T5X and SeqIO. More colabs will be added
|
| 27 |
+
to this series regularly!
|
| 28 |
+
|
| 29 |
+
#### [Fine-tuning a model](usage/finetune.md)
|
| 30 |
+
|
| 31 |
+
This tutorial outlines the steps to fine-tune an existing pre-trained model with
|
| 32 |
+
T5X on common downstream Tasks/Mixtures available on SeqIO. This is one of the
|
| 33 |
+
simplest and most common use cases of T5X. If you're new to T5X, this tutorial
|
| 34 |
+
is the recommended starting point.
|
| 35 |
+
|
| 36 |
+
#### [Running evaluation on a model](usage/eval.md)
|
| 37 |
+
|
| 38 |
+
This tutorial outlines the steps to evaluate a model with T5X on downstream
|
| 39 |
+
Tasks/Mixtures defined in SeqIO.
|
| 40 |
+
|
| 41 |
+
#### [Running inference on a model](usage/infer.md)
|
| 42 |
+
|
| 43 |
+
This tutorial outlines the steps to run inference on a model with T5X.
|
| 44 |
+
|
| 45 |
+
#### [Training a model from scratch](usage/pretrain.md)
|
| 46 |
+
|
| 47 |
+
This tutorial outlines the steps to pretrain a model with T5X on Tasks/Mixtures
|
| 48 |
+
defined in SeqIO.
|
| 49 |
+
|
| 50 |
+
#### [Gin Primer](usage/gin.md)
|
| 51 |
+
|
| 52 |
+
This tutorial provides a quick introduction to Gin, a lightweight configuration
|
| 53 |
+
framework for Python that is used to configure training, eval and inference jobs
|
| 54 |
+
on T5X.
|
| 55 |
+
|
| 56 |
+
#### [Partitioning Primer](usage/partitioning.md)
|
| 57 |
+
|
| 58 |
+
This tutorial provides background on what model and data partitioning are and
|
| 59 |
+
how it can be configured in T5X.
|
| 60 |
+
|
| 61 |
+
#### [Metrics Overview](usage/metrics.md)
|
| 62 |
+
|
| 63 |
+
This tutorial provides an overview of how metrics can be used and customized to
|
| 64 |
+
evaluate T5X models.
|
| 65 |
+
|
t5x-main/docs/index.rst
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
******************************
|
| 2 |
+
T5X
|
| 3 |
+
******************************
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
T5X is a modular, composable, research-friendly framework for high-performance,
|
| 7 |
+
configurable, self-service training, evaluation, and inference of sequence
|
| 8 |
+
models (starting with language) at many scales.
|
| 9 |
+
|
| 10 |
+
It is essentially a new and improved implementation of the
|
| 11 |
+
`T5 codebase <https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md>`__
|
| 12 |
+
(based on Mesh TensorFlow) in JAX and Flax. To learn more, see the
|
| 13 |
+
`T5X Paper <https://arxiv.org/abs/2203.17189>`__.
|
| 14 |
+
|
| 15 |
+
.. toctree::
|
| 16 |
+
:maxdepth: 2
|
| 17 |
+
:caption: Table of Contents
|
| 18 |
+
|
| 19 |
+
Quick Start <overview>
|
| 20 |
+
Tutorials <tutorials>
|
| 21 |
+
Usage Guides <usage/index>
|
| 22 |
+
Models <models>
|
| 23 |
+
api_reference/index
|
| 24 |
+
contributions
|
t5x-main/docs/models.md
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Models
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
This page lists the available pre-trained T5 models. To use a pre-trained model,
|
| 5 |
+
you need a Gin config file that defines the model params, and the model
|
| 6 |
+
checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin
|
| 7 |
+
configs for common T5 pre-trained models have been made available for use in
|
| 8 |
+
T5X. Following is a list of these pre-trained models and their Gin and
|
| 9 |
+
checkpoint locations.
|
| 10 |
+
|
| 11 |
+
+ All checkpoints:
|
| 12 |
+
[`gs://t5-data/pretrained_models/t5x/`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/)
|
| 13 |
+
+ All Gin files:
|
| 14 |
+
[`t5x/configs/models/`](https://github.com/google-research/t5x/blob/main/t5x/configs/)
|
| 15 |
+
|
| 16 |
+
### Selecting a model:
|
| 17 |
+
|
| 18 |
+
Publicly Available Models:
|
| 19 |
+
|
| 20 |
+
Model | Use Case
|
| 21 |
+
---------------------------------------------------- | --------
|
| 22 |
+
[T5 1.1](#t5-11-checkpoints) | Improved T5, recommended for most research. English only.
|
| 23 |
+
[T5](#t5-checkpoints) | The original T5 work for reproducibility. English only.
|
| 24 |
+
[T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [prompt tuning paper](https://arxiv.org/abs/2104.08691).
|
| 25 |
+
[mT5](#mt5-checkpoints) | Multilingual T5. Recommended for multilingual research. Note that at smaller scales (at least through XL), mT5 performance is lower than T5 on English tasks.
|
| 26 |
+
[mT5 LM-Adapted](#mt5-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647).
|
| 27 |
+
[umT5](#umt5-checkpoints) | umT5, an updated mT5 model trained using a more uniform language distribution, per [the UniMax paper](https://openreview.net/forum?id=kXwdL1cWOAi).
|
| 28 |
+
[ByT5](#byt5-checkpoints) | ByT5. A "token-free" model that uses UTF-8 bytes for input and output. Recommended for tasks involving word-internal phenomena such as spelling, pronunciation, or morphology.
|
| 29 |
+
[LongT5](#longt5-checkpoints) | Recommended checkpoints to fine-tune for long input sequence tasks
|
| 30 |
+
[MoE](#mixture-of-experts-moe-checkpoints) | Useful for MoE experimentation.
|
| 31 |
+
[Flan-T5](#flan-t5-checkpoints) | General purpose T5 checkpoints for few-shot and finetuning. We recommend Flan-T5 over vanilla T5 and T5 LM-adapted
|
| 32 |
+
[UL2](#ul2-checkpoints) | Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131)
|
| 33 |
+
[BigScience](#bigscience-checkpoints) | Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832)
|
| 34 |
+
[FLIP](#flip-checkpoints) | Language-Image models trained with an alternative to CLIP, presented in the [FLIP paper](https://arxiv.org/abs/2212.00794)
|
| 35 |
+
[RankGen](#rankgen-checkpoints) | 1.2B parameter encoder model for English to score model generations given a prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726)
|
| 36 |
+
[Dipper](#dipper-checkpoints) | 11B parameter paraphrase generation model from the [Dipper paper](https://arxiv.org/abs/2303.13408)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
### Public Research Models
|
| 40 |
+
|
| 41 |
+
#### T5 Checkpoints
|
| 42 |
+
|
| 43 |
+
These are the checkpoints used in the paper [Exploring the Limits of Transfer
|
| 44 |
+
Learning with a Unified Text-to-Text
|
| 45 |
+
Transformer](https://arxiv.org/abs/1910.10683). They are encoder-decoder models
|
| 46 |
+
pre-trained on [C4](https://www.tensorflow.org/datasets/catalog/c4) with a "span
|
| 47 |
+
corruption" denoising objective, in addition to a mixture of downstream tasks
|
| 48 |
+
including: GLUE, SuperGLUE, CNN/Daily Mail, SQuAD, and WMT.
|
| 49 |
+
|
| 50 |
+
**Vocabulary:**
|
| 51 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 52 |
+
|
| 53 |
+
Model | Gin File Location | Checkpoint Location
|
| 54 |
+
-------- | ------------------------------------------------------------------------------ | -------------------
|
| 55 |
+
T5 Small | [t5_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_small)
|
| 56 |
+
T5 Base | [t5_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_base/checkpoint_999900](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_base)
|
| 57 |
+
T5 Large | [t5_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_large/checkpoint_1000700](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_large)
|
| 58 |
+
T5 3B | [t5_3B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/3B.gin) | [gs://t5-data/pretrained_models/t5x/t5_3B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_3B)
|
| 59 |
+
T5 11B | [t5_11B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/11B.gin) | [gs://t5-data/pretrained_models/t5x/t5_11B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_11B)
|
| 60 |
+
|
| 61 |
+
#### T5 1.1 Checkpoints
|
| 62 |
+
|
| 63 |
+
These are similar to the models from [Exploring the Limits of Transfer Learning
|
| 64 |
+
with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), but
|
| 65 |
+
with the following improvements:
|
| 66 |
+
|
| 67 |
+
* GEGLU activation in feed-forward hidden layer, rather than ReLU - see
|
| 68 |
+
https://arxiv.org/abs/2002.05202 .
|
| 69 |
+
* Dropout was turned off in pre-training (quality win). Dropout should be
|
| 70 |
+
re-enabled during fine-tuning.
|
| 71 |
+
* Pre-trained on C4 only without mixing in the downstream tasks.
|
| 72 |
+
* no parameter sharing between embedding and classifier layer
|
| 73 |
+
* "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit
|
| 74 |
+
different - larger d_model and smaller num_heads and d_ff.
|
| 75 |
+
|
| 76 |
+
For English-language, sequence-to-sequence-style tasks (ones where the goal is
|
| 77 |
+
to map from an input text sequence to a target sequence) these are usually the
|
| 78 |
+
best models to fine-tune.
|
| 79 |
+
|
| 80 |
+
**Vocabulary:**
|
| 81 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 82 |
+
|
| 83 |
+
Model | Gin File Location | Checkpoint Location
|
| 84 |
+
------------ | ---------------------------------------------------------------------------------- | -------------------
|
| 85 |
+
T5 1.1 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small)
|
| 86 |
+
T5 1.1 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_base)
|
| 87 |
+
T5 1.1 Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_large)
|
| 88 |
+
T5 1.1 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xl)
|
| 89 |
+
T5 1.1 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xxl)
|
| 90 |
+
|
| 91 |
+
#### T5 1.1 LM-Adapted Checkpoints
|
| 92 |
+
|
| 93 |
+
These "LM-adapted" models are initialized from T5 1.1 (above) and trained for an
|
| 94 |
+
additional 100K steps on the LM objective discussed in the
|
| 95 |
+
[T5 paper](https://arxiv.org/abs/1910.10683). This adaptation improves the
|
| 96 |
+
ability of the model to be used for
|
| 97 |
+
[prompt tuning](https://arxiv.org/abs/2104.08691). These checkpoints were also
|
| 98 |
+
used within the BigScience [T0](https://arxiv.org/abs/2110.08207) project.
|
| 99 |
+
|
| 100 |
+
**Vocabulary:**
|
| 101 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 102 |
+
|
| 103 |
+
Model | Gin File Location | Checkpoint Location
|
| 104 |
+
-------------------- | ------------------------------------------------------------------------------------------------------------------- | -------------------
|
| 105 |
+
T5 1.1 LM-100K Small | [t5_1_1_small.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin) | [t5_1_1_lm100k_small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_small)
|
| 106 |
+
T5 1.1 LM-100K Base | [t5_1_1_base.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_base.gin) | [t5_1_1_lm100k_base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_base)
|
| 107 |
+
T5 1.1 LM-100K Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_large.gin) | [t5_1_1_lm100k_large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_large)
|
| 108 |
+
T5 1.1 LM-100K XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xl.gin) | [t5_1_1_lm100k_xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl)
|
| 109 |
+
T5 1.1 LM-100K XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xxl.gin) | [t5_1_1_lm100k_xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
#### mT5 Checkpoints
|
| 113 |
+
|
| 114 |
+
These are the checkpoints used in the paper
|
| 115 |
+
[mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer](https://aclanthology.org/2021.naacl-main.41/).
|
| 116 |
+
They are encoder-decoder models trained on
|
| 117 |
+
[multilingual C4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual)
|
| 118 |
+
with a denoising objective. These are the best checkpoints to fine-tune for
|
| 119 |
+
non-English sequence-to-sequence tasks.
|
| 120 |
+
|
| 121 |
+
**Vocabulary:**
|
| 122 |
+
[mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra)
|
| 123 |
+
|
| 124 |
+
Model | Gin File Location | Checkpoint Location
|
| 125 |
+
--------- | ---------------------------------------------------------------------------- | -------------------
|
| 126 |
+
mT5 Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [gs://t5-data/pretrained_models/t5x/mt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_small)
|
| 127 |
+
mT5 Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_base)
|
| 128 |
+
mT5 Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_large)
|
| 129 |
+
mT5 XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xl)
|
| 130 |
+
mT5 XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xxl)
|
| 131 |
+
|
| 132 |
+
#### mT5 LM-Adapted Checkpoints
|
| 133 |
+
|
| 134 |
+
These are the checkpoints released as part of the
|
| 135 |
+
[zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647).
|
| 136 |
+
|
| 137 |
+
These "LM-adapted" models are initialized from mT5 (above) and trained for an
|
| 138 |
+
additional 100K steps on the LM objective discussed in the
|
| 139 |
+
[T5 paper](https://arxiv.org/abs/1910.10683).
|
| 140 |
+
|
| 141 |
+
This adaptation improves the ability of the model to be used for
|
| 142 |
+
[prompt tuning](https://arxiv.org/abs/2104.08691).
|
| 143 |
+
|
| 144 |
+
**Vocabulary:**
|
| 145 |
+
[mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra)
|
| 146 |
+
|
| 147 |
+
Model | Gin File Location | Checkpoint Location
|
| 148 |
+
-------------------- | ---------------------------------------------------------------------------- | -------------------
|
| 149 |
+
mT5 LM-Adapted Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [mt5_lm_adapted/small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/small/checkpoint_1100000)
|
| 150 |
+
mT5 LM-Adapted Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [mt5_lm_adapted/base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/base/checkpoint_1100000)
|
| 151 |
+
mT5 LM-Adapted Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [mt5_lm_adapted/large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/large/checkpoint_1100000)
|
| 152 |
+
mT5 LM-Adapted XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [mt5_lm_adapted/xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xl/checkpoint_1100000)
|
| 153 |
+
mT5 LM-Adapted XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [mt5_lm_adapted/xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xxl/checkpoint_1100000)
|
| 154 |
+
|
| 155 |
+
#### umT5 Checkpoints
|
| 156 |
+
|
| 157 |
+
These are the checkpoints described in the paper [UniMax: Fairer and More
|
| 158 |
+
Effective Language Sampling for Large-Scale Multilingual
|
| 159 |
+
Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi). umT5 is similar to
|
| 160 |
+
mT5 (see above); both are multilingual encoder-decoder models ranging from 300M
|
| 161 |
+
to 13B parameters, trained on the mC4 corpus using a denoising objective. umT5
|
| 162 |
+
is trained on a fresher version of the mC4 corpus (3.1.0), and with a more
|
| 163 |
+
uniform language balancing strategy.
|
| 164 |
+
|
| 165 |
+
**Vocabulary:** [umt5.256000](https://console.cloud.google.com/storage/browser/t5-data/vocabs/umt5.256000)
|
| 166 |
+
|
| 167 |
+
Model | Gin File Location | Checkpoint Location
|
| 168 |
+
---------- | --------------------------------------------------------------------------------------------------------- | -------------------
|
| 169 |
+
umT5 Small | [umt5/pretrain_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_small.gin) | [umt5/small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/small/checkpoint_1000000)
|
| 170 |
+
umT5 Base | [umt5/pretrain_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_base.gin) | [umt5/base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/base/checkpoint_1000000)
|
| 171 |
+
umT5 XL | [umt5/pretrain_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin) | [umt5/xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xl/checkpoint_1000000)
|
| 172 |
+
umT5 XXL | [umt5/pretrain_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin) | [umt5/xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xxl/checkpoint_1000000)
|
| 173 |
+
|
| 174 |
+
#### ByT5 Checkpoints
|
| 175 |
+
|
| 176 |
+
These are the checkpoints used in the paper
|
| 177 |
+
[ByT5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models](https://aclanthology.org/2022.tacl-1.17/).
|
| 178 |
+
They are similar to mT5 (above), but are "token-free", processing text as raw
|
| 179 |
+
UTF-8 bytes, as opposed to using a pretrained subword vocabulary. These models
|
| 180 |
+
are more robust to character-level noise, and outperform parameter-matched mT5
|
| 181 |
+
models in many settings, particularly on word-level tasks sensitive to spelling,
|
| 182 |
+
pronunciation, or morphology. However inference is significantly slower, up to
|
| 183 |
+
10x depending on the task.
|
| 184 |
+
|
| 185 |
+
**Vocabulary:** None
|
| 186 |
+
|
| 187 |
+
Model | Gin File Location | Checkpoint Location
|
| 188 |
+
---------- | ------------------------------------------------------------------------------ | -------------------
|
| 189 |
+
ByT5 Small | [byt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/small.gin) | [gs://t5-data/pretrained_models/t5x/byt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_small)
|
| 190 |
+
ByT5 Base | [byt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/base.gin) | [gs://t5-data/pretrained_models/t5x/byt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_base)
|
| 191 |
+
ByT5 Large | [byt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/large.gin) | [gs://t5-data/pretrained_models/t5x/byt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_large)
|
| 192 |
+
ByT5 XL | [byt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xl)
|
| 193 |
+
ByT5 XXL | [byt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xxl)
|
| 194 |
+
|
| 195 |
+
#### LongT5 Checkpoints
|
| 196 |
+
|
| 197 |
+
These are the checkpoints used in the paper
|
| 198 |
+
[LongT5: Efficient Text-to-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916).
|
| 199 |
+
They are encoder-decoder models trained on
|
| 200 |
+
[C4](https://www.tensorflow.org/datasets/catalog/c4) using the PEGASUS Principle
|
| 201 |
+
Sentences Generation objective. These are the recommended checkpoints to
|
| 202 |
+
fine-tune for long input sequence tasks.
|
| 203 |
+
|
| 204 |
+
##### LongT5 Local Attention Checkpoints
|
| 205 |
+
|
| 206 |
+
The checkpoints below use local attention, which uses a sliding window to reduce
|
| 207 |
+
training time from quadratic (with regards to input length) to linear. These are
|
| 208 |
+
the recommended checkpoints to use for faster training/inference time.
|
| 209 |
+
|
| 210 |
+
**Vocabulary:**
|
| 211 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 212 |
+
|
| 213 |
+
Model | Gin File Location | Checkpoint Location
|
| 214 |
+
---------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -------------------
|
| 215 |
+
LongT5 Local Attention Base | [longt5/models/longt5_1_1_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_base)
|
| 216 |
+
LongT5 Local Attention Large | [longt5/models/longt5_1_1_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_large)
|
| 217 |
+
|
| 218 |
+
##### LongT5 Transient Global Attention Checkpoints
|
| 219 |
+
|
| 220 |
+
The checkpoints below use transient global attention, which introduces global
|
| 221 |
+
tokens at each encoder layer to allow tokens to interact with each other at
|
| 222 |
+
longer distances. These are the recommended checkpoints to use for increased
|
| 223 |
+
performance on long input sequence tasks.
|
| 224 |
+
|
| 225 |
+
**Vocabulary:**
|
| 226 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 227 |
+
|
| 228 |
+
Model | Gin File Location | Checkpoint Location
|
| 229 |
+
------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------
|
| 230 |
+
LongT5 Base | [longt5/models/longt5_1_1_transient_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_base)
|
| 231 |
+
LongT5 Large | [longt5/models/longt5_1_1_transient_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_large)
|
| 232 |
+
LongT5 XL | [longt5/models/longt5_1_1_transient_xl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_xl.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_xl)
|
| 233 |
+
|
| 234 |
+
#### Mixture of Experts (MoE) Checkpoints
|
| 235 |
+
|
| 236 |
+
These MoE checkpoints need to be used with T5X MoE overrides -- specifically,
|
| 237 |
+
the MoeTrainer and the MoePjitPartitioner. For example, for fine-tuning, use the
|
| 238 |
+
[MoE fine-tune run config](https://github.com/google-research/t5x/blob/main/t5x/contrib/moe/configs/runs/finetune.gin).
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
##### Converted Mesh Tensorflow checkpoints
|
| 242 |
+
|
| 243 |
+
[Switch Transformer model](https://arxiv.org/abs/2101.03961).
|
| 244 |
+
|
| 245 |
+
**Vocabulary:**
|
| 246 |
+
[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
Model | Gin File Location | Checkpoint Location
|
| 250 |
+
---------------------------------------- | ------------------------------------------------------------------------------------------------------------ | -------------------
|
| 251 |
+
Switch Transformer Base 8 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e8)
|
| 252 |
+
Switch Transformer Base 16 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e16/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e16)
|
| 253 |
+
Switch Transformer Base 32 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e32/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e32)
|
| 254 |
+
Switch Transformer Base 64 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e64/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e64)
|
| 255 |
+
Switch Transformer Base 128 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e128/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e128)
|
| 256 |
+
Switch Transformer Base 256 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e256/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e256)
|
| 257 |
+
Switch Transformer Large 128 Experts | [switch_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_large.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/large/e128/checkpoint_483100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/large/e128)
|
| 258 |
+
Switch Transformer XXL 128 Experts | [switch_xxl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_xxl.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128/checkpoint_634600](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128)
|
| 259 |
+
Switch Transformer C 2048 Experts (1.6T) | [switch_c.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_c.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048/checkpoint_611800](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
#### Flan-T5 Checkpoints
|
| 263 |
+
|
| 264 |
+
These are the checkpoints released as part of the paper
|
| 265 |
+
[Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416).
|
| 266 |
+
They were initialized from the
|
| 267 |
+
[T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) and instruction-finetuned.
|
| 268 |
+
|
| 269 |
+
They significantly outperform the LM-adapted checkpoints. For example,
|
| 270 |
+
Flan-T5-XXL outperforms T5-LM-XXL by 26.6% absolute on the normalized average
|
| 271 |
+
score. It even outperforms a much larger PaLM 62B model on
|
| 272 |
+
[BigBench Hard](https://arxiv.org/abs/2210.09261) a set of challenging BigBench
|
| 273 |
+
benchmark.
|
| 274 |
+
|
| 275 |
+
Unlike the vanilla T5 checkpoints, these can be directly used for few-shot
|
| 276 |
+
prompting as well as standard finetuning. See
|
| 277 |
+
[Chung et al. 2022](https://arxiv.org/abs/2210.11416) for details.
|
| 278 |
+
|
| 279 |
+
Model | Gin File Location | Checkpoint Location
|
| 280 |
+
------------- | ---------------------------------------------------------------------------------- | -------------------
|
| 281 |
+
Flan-T5 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000)
|
| 282 |
+
Flan-T5 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000)
|
| 283 |
+
Flan-T5 Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000)
|
| 284 |
+
Flan-T5 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000)
|
| 285 |
+
Flan-T5 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000)
|
| 286 |
+
|
| 287 |
+
#### UL2 Checkpoints
|
| 288 |
+
|
| 289 |
+
Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the
|
| 290 |
+
UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131). Checkpoints
|
| 291 |
+
are released at
|
| 292 |
+
https://github.com/google-research/google-research/tree/master/ul2#checkpoints.
|
| 293 |
+
|
| 294 |
+
#### BigScience Checkpoints
|
| 295 |
+
|
| 296 |
+
Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832),
|
| 297 |
+
released at
|
| 298 |
+
https://github.com/bigscience-workshop/architecture-objective/tree/main#checkpoints.
|
| 299 |
+
|
| 300 |
+
#### FLIP Checkpoints
|
| 301 |
+
|
| 302 |
+
Language-Image models trained with an alternative to CLIP, presented in the
|
| 303 |
+
[FLIP paper](https://arxiv.org/abs/2212.00794). Checkpoints are released at
|
| 304 |
+
https://github.com/facebookresearch/flip#results-and-pre-trained-flip-models.
|
| 305 |
+
|
| 306 |
+
#### RankGen Checkpoints
|
| 307 |
+
|
| 308 |
+
1.2B parameter encoder model for English to score model generations given a
|
| 309 |
+
prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726).
|
| 310 |
+
Checkpoints are released at
|
| 311 |
+
https://github.com/google-research/google-research/tree/master/rankgen.
|
| 312 |
+
|
| 313 |
+
#### Dipper Checkpoints
|
| 314 |
+
|
| 315 |
+
11B parameter paraphrase generation model from the
|
| 316 |
+
[Dipper paper](https://arxiv.org/abs/2303.13408). Checkpoints are released at
|
| 317 |
+
https://github.com/google-research/google-research/tree/master/dipper.
|
| 318 |
+
|
t5x-main/docs/overview.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
```{include} ../README.md
|
| 2 |
+
```
|
t5x-main/docs/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sphinx>=4.4.0
|
| 2 |
+
myst_parser>=0.16.1
|
| 3 |
+
myst_nb
|
| 4 |
+
sphinx-design
|
| 5 |
+
sphinx-book-theme
|
| 6 |
+
|
| 7 |
+
# Must install t5x itself for notebook execution and autodocs to work.
|
| 8 |
+
.
|
t5x-main/docs/t5x.png
ADDED
|
Git LFS Details
|
t5x-main/docs/tutorials.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# T5X Introductory Tutorial Series
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Overview
|
| 5 |
+
|
| 6 |
+
This series of guides is a self-contained introduction to T5X, a modular,
|
| 7 |
+
composable, research-friendly framework for high-performance, configurable,
|
| 8 |
+
self-service training, evaluation, and inference of sequence models (starting
|
| 9 |
+
with language) at many scales.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## How to Use These Guides
|
| 13 |
+
|
| 14 |
+
Most entries in this series are colab notebooks (click the blue banners to the
|
| 15 |
+
right of each heading below), allowing you to run our tutorial code
|
| 16 |
+
interactively. We encourage you to do that! Play around, change things, see what
|
| 17 |
+
happens!
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## T5X Guides
|
| 21 |
+
|
| 22 |
+
### Codelab 1: An Introduction to T5X
|
| 23 |
+
|
| 24 |
+
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
|
| 25 |
+
|
| 26 |
+
In this colab, you will learn about some of the basic T5X components and put
|
| 27 |
+
them to use to run training, inference, and evaluation on natural text inputs.
|
| 28 |
+
|
| 29 |
+
### Codelab 2: Training Deep Dive
|
| 30 |
+
|
| 31 |
+
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
|
| 32 |
+
|
| 33 |
+
In this colab, you will dive into how to restore T5X models from checkpoints and
|
| 34 |
+
run training, while also getting an introduction to the T5X trainer.
|
| 35 |
+
|
| 36 |
+
### Codelab 3: Inference Deep Dive
|
| 37 |
+
|
| 38 |
+
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
|
| 39 |
+
|
| 40 |
+
In this colab, you will dive into how the Interactive Model does decoding to
|
| 41 |
+
generate predictions and scores for a given input.
|
| 42 |
+
|
| 43 |
+
### Codelab 4: Evaluation Deep Dive
|
| 44 |
+
|
| 45 |
+
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in colab" style="float:left"/></a><br>
|
| 46 |
+
|
| 47 |
+
In this colab, you will dive into how the InteractiveModel takes a batch of
|
| 48 |
+
inputs and targets and runs evaluation to produce various metrics.
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
### More Colabs coming soon!
|
t5x-main/docs/usage/auxiliary.md
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auxiliary Job
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Introduction
|
| 5 |
+
|
| 6 |
+
This page outlines the steps needed to use the auxiliary job capabilities
|
| 7 |
+
available in T5X.
|
| 8 |
+
|
| 9 |
+
## Overview
|
| 10 |
+
|
| 11 |
+
There are a variety of situations in which running a single job is insufficient
|
| 12 |
+
or suboptimal. For example, consider the following scenarios:
|
| 13 |
+
|
| 14 |
+
+ You want to keep track of evaluation (`infer_eval` or `train_eval`) metrics
|
| 15 |
+
per checkpoint, but evaluation takes a very long time due to having a large
|
| 16 |
+
eval dataset, slow decoding, or multiple tasks to evaluate.
|
| 17 |
+
|
| 18 |
+
+ You want to finetune every checkpoint on a downstream task as you train.
|
| 19 |
+
|
| 20 |
+
+ You have customized evaluation code that you want to run on every checkpoint
|
| 21 |
+
as you train, but that does not naturally fit within a `seqio.Evaluator`
|
| 22 |
+
framework.
|
| 23 |
+
|
| 24 |
+
In cases like these, users can make use of the auxiliary job functionality. At a
|
| 25 |
+
high-level, the auxiliary job will launch a new job every time a new checkpoint
|
| 26 |
+
is saved. This new job can either re-use the `train.py` binary (e.g. for
|
| 27 |
+
continuous finetuning) or a different one. For example, this allows users to
|
| 28 |
+
perform continuous evaluation (using `eval.py`) without slowing down the
|
| 29 |
+
training job. We will provide detailed examples showing how to use the auxiliary
|
| 30 |
+
job for these use-cases.
|
| 31 |
+
|
| 32 |
+
When this new job is launched, the controller will replace four gin macros:
|
| 33 |
+
`MODEL_DIR`, `MIXTURE_OR_TASK_NAME`,`INITIAL_CHECKPOINT_PATH`, `TRAIN_STEPS`.
|
| 34 |
+
The second of these is set by the user-controlled flag (more on this below), and
|
| 35 |
+
the third one is equal to the last checkpoint seen. Aside from this, users are
|
| 36 |
+
free to modify the configuration as needed. Beyond gin macros, the auxiliary job
|
| 37 |
+
can also have different resource requirements, priority, and even cell placement
|
| 38 |
+
from the train job.
|
| 39 |
+
|
| 40 |
+
## Example 1: Separate evaluation job.
|
| 41 |
+
|
| 42 |
+
### Step 1: Choose a model architecture.
|
| 43 |
+
|
| 44 |
+
Similar to pretraining, we will need some gin configuration. For this example,
|
| 45 |
+
we will use the T5-1.1-Base model.
|
| 46 |
+
|
| 47 |
+
### Step 2: Choose a SeqIO Task/Mixture for training and evaluation.
|
| 48 |
+
|
| 49 |
+
In this example, we will use the classic task of English-French translation from
|
| 50 |
+
WMT14, which is conveniently available as a SeqIO task in the tasks file from
|
| 51 |
+
the T5 tasks under the name `'wmt_enfr14_v003'`.
|
| 52 |
+
|
| 53 |
+
### Step 3: Write a Gin config.
|
| 54 |
+
|
| 55 |
+
Unlike pretraining or finetuning, we will need two gin files for this setup: one
|
| 56 |
+
for the training job, and one for the auxiliary job. The train gin file will
|
| 57 |
+
have the same requirements as the gin file for pretraining or finetuning. The
|
| 58 |
+
auxiliary job gin file can leverage these gin files or be its own independent
|
| 59 |
+
gin file, depending on the user’s choice. For this example, we will make a new
|
| 60 |
+
gin which is mostly a wrapper around `pretrain.gin` with some additional
|
| 61 |
+
hardcoded features. We will use this gin file for the train job and `eval.gin`
|
| 62 |
+
for the auxiliary job.
|
| 63 |
+
|
| 64 |
+
### Step 4: Launch your experiment.
|
| 65 |
+
|
| 66 |
+
Our sample script will be quite similar to the one used in pretraining and
|
| 67 |
+
finetuning, but with a few additional flags which we describe below.
|
| 68 |
+
|
| 69 |
+
+ `auxiliary_job_mixtures`: This is a comma-separated list of mixtures. A
|
| 70 |
+
separate auxiliary job will be run for each mixture and will replace the gin
|
| 71 |
+
macro `MIXTURE_OR_TASK_NAME`. Note that you need this flag even if you are
|
| 72 |
+
using a custom binary, which does not need a mixture since otherwise no
|
| 73 |
+
auxiliary job will run.
|
| 74 |
+
|
| 75 |
+
+ `auxiliary_job_gin_file`: This is identical to `gin_file`, except it is used
|
| 76 |
+
for the auxiliary job instead of the train job.
|
| 77 |
+
|
| 78 |
+
+ `replace_gin_file`: If True, this auxiliary launcher will not use any of the
|
| 79 |
+
gin files from train job. This is necessary when using a binary different
|
| 80 |
+
from `train.py`, since the top-level functions will not match.
|
| 81 |
+
|
| 82 |
+
+ `auxiliary_job_cell`: The cell in which to run your job. Note that this can
|
| 83 |
+
be different from the training cell.
|
| 84 |
+
|
| 85 |
+
+ `auxiliary_job_platform`: The platform to use for the auxiliary. Note that
|
| 86 |
+
this can be different from the one use for the train job, allowing users to
|
| 87 |
+
use smaller configurations for evaluation than needed for training.
|
| 88 |
+
|
| 89 |
+
+ `auxiliary_job_build_target`: The binary to use for auxiliary job.
|
| 90 |
+
|
| 91 |
+
+ `final_auxiliary_job_steps`: This flag controls how many additional steps to
|
| 92 |
+
take when using the auxiliary job for finetuning. Setting to 0 enables
|
| 93 |
+
continuous evaluation.
|
| 94 |
+
|
| 95 |
+
We provide the sample script below.
|
| 96 |
+
|
| 97 |
+
```sh
|
| 98 |
+
declare -a ARGS=(
|
| 99 |
+
--cell=iz
|
| 100 |
+
--platform=jd=2x2
|
| 101 |
+
--final_auxiliary_job_steps=0
|
| 102 |
+
--replace_gin_file=True
|
| 103 |
+
--auxiliary_job_mixtures=wmt14_enfr_v003
|
| 104 |
+
--auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin
|
| 105 |
+
--auxiliary_job_cell=iz
|
| 106 |
+
--auxiliary_job_platform=jd=2x2
|
| 107 |
+
--auxiliary_job_build_target_path=//t5x:eval
|
| 108 |
+
--gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
gxm t5x/google/xm_launch.py "${ARGS[@]}"
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Example 2: Continuous finetuning job.
|
| 115 |
+
|
| 116 |
+
In this example, we will be pretraining a model on a span corruption task on the
|
| 117 |
+
C4 dataset, and finetuning it on the WMT'14 English-French translation task. As
|
| 118 |
+
before, we will launch a new auxiliary job once every checkpoint is saved.
|
| 119 |
+
However, instead of using the `eval.py` binary, we will use the `train.py`
|
| 120 |
+
binary.
|
| 121 |
+
|
| 122 |
+
### Step 1: Choose a model architecture.
|
| 123 |
+
|
| 124 |
+
We will use the T5-1.1-Base model as in the previous example.
|
| 125 |
+
|
| 126 |
+
### Step 2: Choose a SeqIO Task/Mixture for training and evaluation.
|
| 127 |
+
|
| 128 |
+
For pretraining, we re-use the span coprruption task `c4_v220_span_corruption`
|
| 129 |
+
available in the T5 mixtures `tasks.py` file.
|
| 130 |
+
|
| 131 |
+
### Step 3: Write a Gin config.
|
| 132 |
+
|
| 133 |
+
As before, we need our gin files to contain all the desired macros in them. We
|
| 134 |
+
thus create two new gin files: `base_c4_pretrain.gin` for the train job and
|
| 135 |
+
`base_wmtenfr14_finetune.gin` for the auxiliary job.
|
| 136 |
+
|
| 137 |
+
### Step 4: Launch your experiment.
|
| 138 |
+
|
| 139 |
+
Our script is quite similar to the first example, with the same flags as before
|
| 140 |
+
but with the appropiate changes. The main distinction is that we must change the
|
| 141 |
+
flag `final_auxiliary_job_steps` to be non-zero to start finetuning. We will
|
| 142 |
+
settle for a modest 200 steps for the sake of demonstration (and evaluate every
|
| 143 |
+
100 steps), but users should use larger steps in realistic scenarios. We also
|
| 144 |
+
use `train.py` binary instead of `eval.py`.
|
| 145 |
+
|
| 146 |
+
We provide the sample script below.
|
| 147 |
+
|
| 148 |
+
```sh
|
| 149 |
+
declare -a ARGS=(
|
| 150 |
+
--cell=iz
|
| 151 |
+
--platform=jd=2x2
|
| 152 |
+
--final_auxiliary_job_steps=200
|
| 153 |
+
--replace_gin_file=True
|
| 154 |
+
--auxiliary_job_mixtures=wmt14_enfr_v003
|
| 155 |
+
--auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin
|
| 156 |
+
--auxiliary_job_cell=iz
|
| 157 |
+
--auxiliary_job_platform=jd=2x2
|
| 158 |
+
--auxiliary_job_build_target_path=//t5x:train
|
| 159 |
+
--gin_file=t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
gxm t5x/google/xm_launch.py "${ARGS[@]}"
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## Common Gotchas.
|
| 166 |
+
|
| 167 |
+
We outline a few common error patterns that we have encountered.
|
| 168 |
+
|
| 169 |
+
+ **Not passing a value for the `auxiliary_mixtures` flag.** Even if you have
|
| 170 |
+
the desired task in your gin file, or you use a differently named macro, you
|
| 171 |
+
should still pass a value for this flag, since launch script will launch a
|
| 172 |
+
new job per value of this flag.
|
| 173 |
+
|
| 174 |
+
+ **Not setting `replace_gin_file=True` when using a different binary from
|
| 175 |
+
train.py.** This will usually yield an error that there is no `train`
|
| 176 |
+
function.
|
| 177 |
+
|
| 178 |
+
+ **No metrics being logged.** It can be tempting to use gin files usually
|
| 179 |
+
used for evaluation. However, one must ensure that the corresponding SeqIO
|
| 180 |
+
evaluators still log to the tensorboard, otherwise you won’t see the
|
| 181 |
+
metrics.
|
| 182 |
+
|
| 183 |
+
+ **Slow `train_eval`.** While the approach outlined above separates out the
|
| 184 |
+
infer_eval job, it may be that even train_eval is too slow. In these
|
| 185 |
+
situations, we suggest adding the metrics from train_eval into the
|
| 186 |
+
`metrics_fn` argument of the SeqIO task and have them be computed in the
|
| 187 |
+
auxiliary job as well. To do this with teacher forcing, you will have to use
|
| 188 |
+
`train.py` instead of `eval.py`.
|
| 189 |
+
|
| 190 |
+
+ **Using `CHECKPOINT_PATH` rather `INITIAL_CHECKPOINT_PATH`.** For legacy
|
| 191 |
+
reasons, the auxiliary job uses the macro `INITIAL_CHECKPOINT_PATH` rather
|
| 192 |
+
than `CHECKPOINT_PATH` as found in `eval.gin`. Make sure to use the latter
|
| 193 |
+
macro building your gin scripts.
|
| 194 |
+
|
| 195 |
+
+ **Gin macros being ignored when passed through the format
|
| 196 |
+
`gin.{MACRO}={VAL}`.** In the current setup, you must include all gin macros
|
| 197 |
+
in the gin script. Attempting to pass them as additional flags will usually
|
| 198 |
+
not work.
|
| 199 |
+
|
| 200 |
+
+ **Not setting `final_auxiliary_job_steps=0` when performing continuous
|
| 201 |
+
evaluation.** The current parameter controller uses this as a check. When
|
| 202 |
+
this is true, it will replace the `EVAL_OUTPUT_DIR` folder with the current
|
| 203 |
+
`MODEL_DIR`, so that the evaluation metrics are saved in the right place and
|
| 204 |
+
the metrics are showed correctly on the tensorboard.
|
t5x-main/docs/usage/decoding.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Decoding
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
This page outlines the decoding functions that T5X provides out-of-the-box and
|
| 5 |
+
how custom decoding functions can be used for a Transformer model, i.e., an
|
| 6 |
+
instance of
|
| 7 |
+
[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb).
|
| 8 |
+
Here we refer to decoding as a process of generating a sequence of items from a
|
| 9 |
+
fixed alphabet (e.g., generating token ids from the vocabulary).
|
| 10 |
+
|
| 11 |
+
There are two major ways to configure the decoding routine. The first method is
|
| 12 |
+
to define a decode function that follows the `DecodeFnCallable` signature. This
|
| 13 |
+
is more restrictive as it enforces the call signature but users don't need to
|
| 14 |
+
modify the model code.
|
| 15 |
+
|
| 16 |
+
The second method is to subclass a model class and override
|
| 17 |
+
`predict_batch_with_aux` method. While this provides more flexibility, it
|
| 18 |
+
requires rewriting the method.
|
| 19 |
+
|
| 20 |
+
## Option 1: defining a decoding function
|
| 21 |
+
|
| 22 |
+
If a desired decoding process can follow `DecodeFnCallable`, it can be
|
| 23 |
+
registered as a private attribute of a
|
| 24 |
+
[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb)
|
| 25 |
+
by passing it as a `decode_fn` argument to its constructor.
|
| 26 |
+
|
| 27 |
+
### Decoding function call signature
|
| 28 |
+
|
| 29 |
+
`DecodeFnCallable` has the following call signature
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
It takes in `inputs`, which is an int32 array with a shape `[batch_size,
|
| 33 |
+
max_decode_len]`. This is an input tokens to the decoder. For the standard
|
| 34 |
+
encoder-decoder models like T5, this is initialized as zeros with a desired
|
| 35 |
+
decoding length. The decoding function will populate the array with the sampled
|
| 36 |
+
token ids and return.
|
| 37 |
+
|
| 38 |
+
For a decoder-only architectures such as a Prefix Language Model, `inputs` can
|
| 39 |
+
be a concatenated sequence of "inputs" and "targets" tokens ids.
|
| 40 |
+
|
| 41 |
+
`tokens_to_logits` is a callable that takes in a batch of token ids and the
|
| 42 |
+
current autoregressive cache, performs the forward pass and returns the
|
| 43 |
+
resulting logits resulting and an updated cache. Note that for incremental
|
| 44 |
+
decoding, this function operates with a single token, i.e., the length dimension
|
| 45 |
+
is assumed to be 1.
|
| 46 |
+
|
| 47 |
+
`DecodeFnCallable` is designed to be as general as possible. This results in
|
| 48 |
+
some of the arguments being somewhat generic for a specialized decoding
|
| 49 |
+
algorithm. For example, `num_decodes` refers to the number of decoded samples to
|
| 50 |
+
be returned. In the case of beam search, `num_decodes` corresponds to what is
|
| 51 |
+
commonly known as `beam_size`, with returned sequences sorted by the beam
|
| 52 |
+
scores. For temperature sampling, we perform `num_decodes` *independent*
|
| 53 |
+
sampling procedures with different random seeds and sort them by the log
|
| 54 |
+
probability of the generated sequences.
|
| 55 |
+
|
| 56 |
+
For custom decoding functions, there might be additional arguments. To support
|
| 57 |
+
these, we provide `**kwargs`.
|
| 58 |
+
|
| 59 |
+
Another usage of `**kwargs` is calling `decoding_fn` multiple times without
|
| 60 |
+
recompiling the model. This pattern is used in
|
| 61 |
+
[Prediction Service](https://github.com/google-research/t5x/blob/main/t5x/google/prediction_service/README.md).
|
| 62 |
+
For a compiled model, different values of `alpha` can be passed e.g.,
|
| 63 |
+
`decoder_params = {"alpha": 0.7}` where `decoder_params` is the argument to
|
| 64 |
+
`predict_batch_with_aux`. It is unpacked and passed to `beam_search` function.
|
| 65 |
+
Note that the Prediction Service uses
|
| 66 |
+
[`predict_batch_with_aux`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch_with_aux%5Cb),
|
| 67 |
+
which is one of the two public methods. This method is useful if auxiliary
|
| 68 |
+
outputs (e.g., scores of the predictions) are to be returned. The other method
|
| 69 |
+
is
|
| 70 |
+
[`predict_batch`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch%5Cb),
|
| 71 |
+
which simply returns the predictions.
|
| 72 |
+
|
| 73 |
+
### Beam search
|
| 74 |
+
|
| 75 |
+
The following lines can be added to a gin file in order to use
|
| 76 |
+
[beam search](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=881;rcl=446762159)
|
| 77 |
+
as a decoding function for an encoder-decoder model.
|
| 78 |
+
|
| 79 |
+
```gin
|
| 80 |
+
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
|
| 81 |
+
models.EncoderDecoderModel.decode_fn = @decoding.beam_search
|
| 82 |
+
decode.beam_search.alpha = 0.6
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Note that we skip the gin boilerplate code such as gin dynamic registration.
|
| 86 |
+
Please refer to [T5X Gin Primer](gin.md) for more details.
|
| 87 |
+
|
| 88 |
+
The beam search behavior is controlled by the arguments passed to `beam_search`.
|
| 89 |
+
We provide details for a few of them below.
|
| 90 |
+
|
| 91 |
+
#### `num_decodes`
|
| 92 |
+
|
| 93 |
+
If `num_decodes` are configured with `gin.register`, it is overridden by the
|
| 94 |
+
value explicitly passed by the caller e.g.,
|
| 95 |
+
`models.EncoderDecoderModel.predict_batch_with_aux`. This is because the
|
| 96 |
+
information about `num_decodes` is needed to prepare the encoder inputs and
|
| 97 |
+
outputs expanded by `num_decodes` times in the batch dimension.
|
| 98 |
+
|
| 99 |
+
We recommend that `num_decodes` be specified *only* in
|
| 100 |
+
`models.EncoderDecoderModel.predict_batch_with_aux`.
|
| 101 |
+
|
| 102 |
+
#### `alpha`
|
| 103 |
+
|
| 104 |
+
This is the brevity penalty introduced in
|
| 105 |
+
[Wu et al. 2016](https://arxiv.org/abs/1609.08144) to penalize short sequences.
|
| 106 |
+
|
| 107 |
+
#### `max_decode_len`
|
| 108 |
+
|
| 109 |
+
For evaluation, we typically don't want to truncate the examples by a specified
|
| 110 |
+
sequence length. Therefore, we dynamically obtain the length information from
|
| 111 |
+
the batch of examples. The default behavior of `seqio.Evaluator` is to use the
|
| 112 |
+
maximum length of a task but, this can be overridden.
|
| 113 |
+
|
| 114 |
+
Since the length information is provided dynamically, we don't set
|
| 115 |
+
`max_decode_len` in gin. Instead we pass the relevant `inputs` array to
|
| 116 |
+
`beam_search` whose length is the dynamically determined maximum length.
|
| 117 |
+
|
| 118 |
+
If `max_decode_len` is explicitly specified via gin, this will override the
|
| 119 |
+
implicitly determined length information unless it is passed by
|
| 120 |
+
`predict_batch_with_aux`.
|
| 121 |
+
|
| 122 |
+
### Temperature sampling
|
| 123 |
+
|
| 124 |
+
[Temperature sampling](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=37;rcl=446762159)
|
| 125 |
+
can be used for multiple decoding strategies. The following lines configures
|
| 126 |
+
temperature sampling as a `decode_fn`.
|
| 127 |
+
|
| 128 |
+
```gin
|
| 129 |
+
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
|
| 130 |
+
models.EncoderDecoderModel.decode_fn = @decoding.temperature_sample
|
| 131 |
+
decoding.temperature_sample:
|
| 132 |
+
temperature = 0.5
|
| 133 |
+
topk = 20
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Similar specification can be used for other model types by replacing
|
| 137 |
+
`models.EncoderDecoderModel` with the relevant model class, e.g.
|
| 138 |
+
`models.PrefixLanguageModel`.
|
| 139 |
+
|
| 140 |
+
The sampling behavior is controlled by the arguments passed to
|
| 141 |
+
`temperature_sample`. We provide details for a few of them below.
|
| 142 |
+
|
| 143 |
+
#### `temperature`
|
| 144 |
+
|
| 145 |
+
A probabilistic model outputs a probability distribution over a pre-defined
|
| 146 |
+
alphabet. For example, a language model outputs *logits*, which are unnormalized
|
| 147 |
+
probability values for each item in the vocabulary. We use a language model as a
|
| 148 |
+
running example. A sampling process involves *sampling* from the predicted
|
| 149 |
+
distribution one item at a time conditioned on the previously generated items
|
| 150 |
+
until a given number of items are generated or a sentinel token that represents
|
| 151 |
+
the end of sequence is generated.
|
| 152 |
+
|
| 153 |
+
Temperature modifies the unnormalized probability distribution at each step. For
|
| 154 |
+
each item $$i$$ in the vocabulary, its probability predicted by the model is
|
| 155 |
+
given by
|
| 156 |
+
|
| 157 |
+
$$p_i \propto \exp\left(\frac{x_i}{T} \right)$$
|
| 158 |
+
|
| 159 |
+
where $$T$$ is the temperature and $$x_i$$ is the logits value corresponding to
|
| 160 |
+
item $$i$$. As $$T \to 0$$, the distribution puts all probability mass to the
|
| 161 |
+
item with the highest probability. In other words, the sampling process becomes
|
| 162 |
+
a greedy search.
|
| 163 |
+
|
| 164 |
+
In the other extreme, as $$T \to \infty$$, the predicted distribution becomes
|
| 165 |
+
uniform.
|
| 166 |
+
|
| 167 |
+
#### `topk`
|
| 168 |
+
|
| 169 |
+
By specifying strictly positive integer value for `topk`, the sampling process
|
| 170 |
+
in each step is limited to the `k` items with highest probabilities. `topk` also
|
| 171 |
+
uses `temperature` to modify the logits corresponding to the top `k` items.
|
| 172 |
+
|
| 173 |
+
#### `topp`
|
| 174 |
+
|
| 175 |
+
By specifying non-zero positive float value for `topp`, the sampling process is
|
| 176 |
+
limited to a subset of the vocabulary $$V^{(p)} \subset V$$, which is defined by
|
| 177 |
+
the smallest set such that
|
| 178 |
+
|
| 179 |
+
$$\sum_{i \in V^{(p)}} p_i \ge p$$
|
| 180 |
+
|
| 181 |
+
where $$p_i$$ is the conditional distribution at each time step for item $$i$$.
|
| 182 |
+
This is called "Nucleus sampling", which was introduced by
|
| 183 |
+
[Holtzman et al. ICLR 2020](https://openreview.net/forum?id=rygGQyrFvH).
|
| 184 |
+
|
| 185 |
+
IMPORTANT: Only one of `topk` or `topp` can be used.
|
| 186 |
+
|
| 187 |
+
## Option 2: subclassing a model class
|
| 188 |
+
|
| 189 |
+
If `DecodeFnCallable` is not flexible enough for your custom decoding function,
|
| 190 |
+
you can subclass the model class and override `predict_batch_with_aux` method.
|
| 191 |
+
While the model class can be any instance of
|
| 192 |
+
[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb),
|
| 193 |
+
we recommend that you subclass the existing models such as
|
| 194 |
+
[`EncoderDecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbEncoderDecoderModel%5Cb)
|
| 195 |
+
and only override `predict_batch_with_aux` method.
|
| 196 |
+
|
| 197 |
+
`predict_batch_with_aux` method also has a required call signature, but it is
|
| 198 |
+
significantly more flexible. It should return a tuple of predicted sequence
|
| 199 |
+
array and auxiliary outputs such as score.
|
t5x-main/docs/usage/eval.md
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluating a Model
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Introduction
|
| 5 |
+
|
| 6 |
+
This page outlines the steps to evaluate a model with T5X on downstream tasks
|
| 7 |
+
defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md).
|
| 8 |
+
|
| 9 |
+
Refer to this tutorial when you have an existing model that you want to
|
| 10 |
+
evaluate. If you would like to fine-tune your model before evaluation, please
|
| 11 |
+
refer to the [fine-tuning](finetune.md) tutorial. You can run evals as part of
|
| 12 |
+
your fine-tuning run as well.
|
| 13 |
+
|
| 14 |
+
## Overview
|
| 15 |
+
|
| 16 |
+
Evaluating a model with T5X consists of the following steps:
|
| 17 |
+
|
| 18 |
+
1. Choose the model to evaluate.
|
| 19 |
+
1. Choose the SeqIO Task/Mixture to evaluate the model on.
|
| 20 |
+
1. Write a Gin file that configures the model, SeqIO Task/Mixture and other
|
| 21 |
+
details of your eval run.
|
| 22 |
+
1. Launch your experiment locally or on XManager.
|
| 23 |
+
1. Monitor your experiment and parse metrics.
|
| 24 |
+
|
| 25 |
+
These steps are explained in detail in the following sections. An example run
|
| 26 |
+
that evaluates a fine-tuned T5-1.1-Small checkpoint on the
|
| 27 |
+
[(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
|
| 28 |
+
is also showcased.
|
| 29 |
+
|
| 30 |
+
## Step 1: Choose a model
|
| 31 |
+
|
| 32 |
+
To evaluate a model, you need a Gin config file that defines the model params,
|
| 33 |
+
and the model checkpoint to load from. For this example, a T5-1.1-Small model
|
| 34 |
+
fine-tuned on the
|
| 35 |
+
[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
|
| 36 |
+
SeqIO Task will be used:
|
| 37 |
+
|
| 38 |
+
+ Model checkpoint -
|
| 39 |
+
[`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
|
| 40 |
+
+ Model Gin file -
|
| 41 |
+
[`t5x/configs/models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 42 |
+
|
| 43 |
+
If you would like to fine-tune your model before evaluation, please follow the
|
| 44 |
+
[fine-tuning](finetune.md) tutorial, and continue to Step 2. A list of all
|
| 45 |
+
available pre-trained models (with model checkpoints and Gin config files) are
|
| 46 |
+
available in the [Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation.
|
| 47 |
+
|
| 48 |
+
## Step 2: Choose a SeqIO Task/Mixture
|
| 49 |
+
|
| 50 |
+
A SeqIO Task encapsulates the data source, the preprocessing logic to be
|
| 51 |
+
performed on the data before querying the model, the postprocessing logic to be
|
| 52 |
+
performed on model outputs, and the metrics to be computed given the
|
| 53 |
+
postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks
|
| 54 |
+
and enables fine-tuning a model on multiple Tasks simultaneously.
|
| 55 |
+
|
| 56 |
+
Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
|
| 57 |
+
[SuperGLUE](https://super.gluebenchmark.com/),
|
| 58 |
+
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
|
| 59 |
+
[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
|
| 60 |
+
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
|
| 61 |
+
implemented as SeqIO Tasks/Mixtures and can be used directly. These
|
| 62 |
+
Tasks/Mixtures are defined in
|
| 63 |
+
[`t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) and
|
| 64 |
+
[`t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).
|
| 65 |
+
|
| 66 |
+
For the example run, you will evaluate the model on the Natural Questions
|
| 67 |
+
benchmark, which has been implemented as the `natural_questions_open` Task in
|
| 68 |
+
[`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021).
|
| 69 |
+
Here's an example of a single row of preprocessed data from this Task:
|
| 70 |
+
|
| 71 |
+
```python
|
| 72 |
+
{
|
| 73 |
+
'inputs_pretokenized': 'nq question: what was the main motive of salt march',
|
| 74 |
+
'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1]
|
| 75 |
+
'targets_pretokenized': 'challenge to British authority',
|
| 76 |
+
'targets': [1921, 12, 2390, 5015, 1],
|
| 77 |
+
'answers': ['challenge to British authority']
|
| 78 |
+
}
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Step 3: Write a Gin Config
|
| 82 |
+
|
| 83 |
+
After choosing the model and SeqIO Task/Mixture for your run, the next step is
|
| 84 |
+
to configure your run using Gin. If you're not familiar with Gin, reading the
|
| 85 |
+
[T5X Gin Primer](gin.md) is recommended.
|
| 86 |
+
|
| 87 |
+
T5X provides a Gin file that configures the T5X eval job (located at
|
| 88 |
+
[`t5x/configs/runs/eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin)),
|
| 89 |
+
and expects a few params from you. These params can be specified in a separate
|
| 90 |
+
Gin file, or via commandline flags. Following are the required params:
|
| 91 |
+
|
| 92 |
+
+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
|
| 93 |
+
For the example run, set this to
|
| 94 |
+
`'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
|
| 95 |
+
+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run eval
|
| 96 |
+
on (from Step 2). For the example run, set this to
|
| 97 |
+
`'natural_questions_open'`.
|
| 98 |
+
+ `EVAL_OUTPUT_DIR`: A path to write eval outputs to. When launching using
|
| 99 |
+
XManager, this path is automatically set and can be accessed from the
|
| 100 |
+
XManager Artifacts page. When running locally using Blaze, you can
|
| 101 |
+
explicitly pass a directory using a flag. Launch commands are provided in
|
| 102 |
+
the next step.
|
| 103 |
+
|
| 104 |
+
In addition to the above params, you will need to import
|
| 105 |
+
[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) and the
|
| 106 |
+
Gin file for the model, which for the example run is
|
| 107 |
+
[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 108 |
+
|
| 109 |
+
```gin
|
| 110 |
+
include 'runs/eval.gin'
|
| 111 |
+
include 'models/t5_small.gin'
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
Note that the `include` statements use relative paths in this example. You will
|
| 115 |
+
pass an appropriate `gin_search_paths` flag to locate these files when launching
|
| 116 |
+
your run. Absolute paths to Gin files can also be used, e.g.
|
| 117 |
+
|
| 118 |
+
```gin
|
| 119 |
+
include 't5x/configs/runs/eval.gin'
|
| 120 |
+
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
You will also need to import the Python module(s) that register SeqIO Tasks and
|
| 124 |
+
Mixtures used in your run. For the example run, we add `import
|
| 125 |
+
google_research.t5_closed_book_qa.t5_cbqa.tasks`
|
| 126 |
+
since it is where 'glue_v002_proportional' is registered.
|
| 127 |
+
|
| 128 |
+
If you choose a module that is not included as a dependency in the T5X trainer
|
| 129 |
+
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=76;rcl=398627055), or if you
|
| 130 |
+
have defined your gin config file in a location other than the
|
| 131 |
+
[T5X config directory](https://github.com/google-research/t5x/blob/main/t5x/configs/), you will
|
| 132 |
+
need to follow the instructions in the
|
| 133 |
+
[Advanced Topics section](#custom-t5x-binaries) to link in the custom gin file
|
| 134 |
+
and/or task definition.
|
| 135 |
+
|
| 136 |
+
Note that for most common Task/Mixtures, such as the `glue_v002_proportional`
|
| 137 |
+
used in this tutorial, the necessary modules are already included. It is also
|
| 138 |
+
possible to skip writing a Gin file and instead pass the params as flags when
|
| 139 |
+
launching the eval job (see instructions in Step 4).
|
| 140 |
+
|
| 141 |
+
Finally, your Gin file should look like this:
|
| 142 |
+
|
| 143 |
+
```gin
|
| 144 |
+
include 't5x/configs/runs/eval.gin'
|
| 145 |
+
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
|
| 146 |
+
|
| 147 |
+
# Register necessary SeqIO Tasks/Mixtures.
|
| 148 |
+
import google_research.t5_closed_book_qa.t5_cbqa.tasks
|
| 149 |
+
|
| 150 |
+
CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
|
| 151 |
+
MIXTURE_OR_TASK_NAME = 'natural_questions_open'
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
See
|
| 155 |
+
[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin)
|
| 156 |
+
for this example.
|
| 157 |
+
|
| 158 |
+
In this example, we run the evaluation on one checkpoint. It is common to
|
| 159 |
+
evaluate with multiple checkpoints. We provide an easy way to do so *without*
|
| 160 |
+
having to recompile the model graph for each checkpoints. This is simply done by
|
| 161 |
+
adding `utils.RestoreCheckpointConfig.mode = "all"` to a gin file. Our
|
| 162 |
+
`t5x/configs/runs/eval.gin` uses "specific" mode.
|
| 163 |
+
|
| 164 |
+
## Step 4: Launch your experiment
|
| 165 |
+
|
| 166 |
+
To launch your experiment locally (for debugging only; larger checkpoints may
|
| 167 |
+
cause issues), run the following on commandline:
|
| 168 |
+
|
| 169 |
+
```sh
|
| 170 |
+
EVAL_OUTPUT_DIR="/tmp/model-eval/"
|
| 171 |
+
python -m t5x.eval_unfragmented \
|
| 172 |
+
--gin_file=t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin \
|
| 173 |
+
--gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
|
| 174 |
+
--alsologtostderr
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
Note that relative paths can be used to locate the gin files. For that, multiple
|
| 178 |
+
comma-separated paths can be passed to the `gin_search_paths` flag, and these
|
| 179 |
+
paths should contain all Gin files used or included in your experiment.
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
You can have a look inside
|
| 183 |
+
[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) to see
|
| 184 |
+
other useful parameters that it is possible to pass in, including dataset split,
|
| 185 |
+
batch size, and random seed.
|
| 186 |
+
|
| 187 |
+
## Step 5: Monitor your experiment and parse metrics
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
After evaluation has completed, you can parse metrics into CSV format using the
|
| 191 |
+
following script:
|
| 192 |
+
|
| 193 |
+
```sh
|
| 194 |
+
EVAL_OUTPUT_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise
|
| 195 |
+
VAL_DIR="$EVAL_OUTPUT_DIR/inference_eval"
|
| 196 |
+
python -m t5.scripts.parse_tb \
|
| 197 |
+
--summary_dir="$VAL_DIR" \
|
| 198 |
+
--seqio_summaries \
|
| 199 |
+
--out_file="$VAL_DIR/results.csv" \
|
| 200 |
+
--alsologtostderr
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
## Next Steps
|
| 204 |
+
|
| 205 |
+
Now that you have successfully evaluated a model on the Natural Questions
|
| 206 |
+
benchmark, here are some topics you might want to explore next:
|
| 207 |
+
|
| 208 |
+
+ [Running inference on a model.](infer.md)
|
| 209 |
+
+ [Fine-tuning a model.](finetune.md)
|
| 210 |
+
+ [Training a model from scratch.](pretrain.md)
|
| 211 |
+
|
| 212 |
+
We also touch upon a few advanced topics related to evaluations below that might
|
| 213 |
+
be useful, especially when customizing your eval job.
|
| 214 |
+
|
| 215 |
+
## Advanced Topics
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
### Defining a custom SeqIO Task/Mixture to evaluate on {.no-toc}
|
| 219 |
+
|
| 220 |
+
Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
|
| 221 |
+
|
| 222 |
+
### Defining a custom metric to evaluate
|
| 223 |
+
|
| 224 |
+
The best way to define a custom metric is to define a new SeqIO Task/Mixture
|
| 225 |
+
that contains this custom metric. Please refer to the SeqIO Documentation on
|
| 226 |
+
[custom metrics](https://github.com/google/seqio/blob/main/README.md#metrics).
|
t5x-main/docs/usage/finetune.md
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fine Tuning a Model
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Introduction
|
| 5 |
+
|
| 6 |
+
This page outlines the steps to fine-tune an existing pre-trained model with T5X
|
| 7 |
+
on common downstream tasks defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). This is one of
|
| 8 |
+
the simplest and most common use cases of T5X. If you're new to T5X, this
|
| 9 |
+
tutorial is the recommended starting point.
|
| 10 |
+
|
| 11 |
+
## Overview
|
| 12 |
+
|
| 13 |
+
Fine-tuning a model with T5X consists of the following steps:
|
| 14 |
+
|
| 15 |
+
1. Choose the pre-trained model to fine-tune.
|
| 16 |
+
2. Choose the SeqIO Task/Mixture to fine-tune the model on.
|
| 17 |
+
3. Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture
|
| 18 |
+
and other details of your fine-tuning run.
|
| 19 |
+
4. Launch your experiment locally or on XManager.
|
| 20 |
+
5. Monitor your experiment and parse metrics.
|
| 21 |
+
|
| 22 |
+
These steps are explained in detail in the following sections. An example run
|
| 23 |
+
that fine-tunes a T5-small checkpoint on WMT14 English to German translation
|
| 24 |
+
benchmark is also showcased.
|
| 25 |
+
|
| 26 |
+
## Step 1: Choose a pre-trained model
|
| 27 |
+
|
| 28 |
+
To use a pre-trained model, you need a Gin config file that defines the model
|
| 29 |
+
params, and the model checkpoint to load from. For your convenience, TensorFlow
|
| 30 |
+
checkpoints and Gin configs for common T5 pre-trained models have been made
|
| 31 |
+
available for use in T5X. A list of all the available pre-trained models (with
|
| 32 |
+
model checkpoints and Gin config files) are available in the
|
| 33 |
+
[Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation.
|
| 34 |
+
|
| 35 |
+
For the example run, you will use the T5 1.1 Small model. The Gin file for this
|
| 36 |
+
model is located at
|
| 37 |
+
[`/t5x/examples/t5/t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin),
|
| 38 |
+
and the checkpoint is located at
|
| 39 |
+
[`gs://t5-data/pretrained_models/t5x/t5_1_1_small`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small).
|
| 40 |
+
|
| 41 |
+
## Step 2: Choose a SeqIO Task/Mixture
|
| 42 |
+
|
| 43 |
+
A SeqIO Task encapsulates the data source, the preprocessing logic to be
|
| 44 |
+
performed on the data before querying the model, the postprocessing logic to be
|
| 45 |
+
performed on model outputs, and the metrics to be computed given the
|
| 46 |
+
postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks
|
| 47 |
+
and enables fine-tuning a model on multiple Tasks simultaneously.
|
| 48 |
+
|
| 49 |
+
### Standard Tasks
|
| 50 |
+
|
| 51 |
+
Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
|
| 52 |
+
[SuperGLUE](https://super.gluebenchmark.com/),
|
| 53 |
+
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
|
| 54 |
+
[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
|
| 55 |
+
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
|
| 56 |
+
implemented as SeqIO Tasks/Mixtures and can be used directly. These
|
| 57 |
+
Tasks/Mixtures are defined in
|
| 58 |
+
[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
|
| 59 |
+
and
|
| 60 |
+
[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).
|
| 61 |
+
|
| 62 |
+
For the example run, you will fine-tune the model on the WMT14 English to German
|
| 63 |
+
translation benchmark, which has been implemented as the
|
| 64 |
+
[`wmt_t2t_ende_v003`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py;l=209;rcl=417815592)
|
| 65 |
+
Task.
|
| 66 |
+
|
| 67 |
+
### Custom Tasks
|
| 68 |
+
|
| 69 |
+
It is also possible to define your own custom task. See the
|
| 70 |
+
[SeqIO documentation](https://github.com/google/seqio/blob/main/README.md) for how to do this. As a note, Tasks
|
| 71 |
+
defined using the
|
| 72 |
+
[old T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/dataset_providers.py)
|
| 73 |
+
may also be used by T5X. If using a custom Task, you will need to follow the
|
| 74 |
+
instructions in the [Advanced Topics section](#custom-t5x-binaries) at the end
|
| 75 |
+
of this tutorial to make sure the module containing your task is included.
|
| 76 |
+
|
| 77 |
+
When defining a custom task, you have the option to cache it on disk before
|
| 78 |
+
fine-tuning. The instructions for this are
|
| 79 |
+
[here](https://github.com/google/seqio/blob/main/README.md#optional-offline-caching). Caching may improve
|
| 80 |
+
performance for tasks with expensive pre-processing. By default, T5X expects
|
| 81 |
+
tasks to be cached. To finetune on a task that has not been cached, set
|
| 82 |
+
`--gin.USE_CACHED_TASKS=False`.
|
| 83 |
+
|
| 84 |
+
## Step 3: Write a Gin Config
|
| 85 |
+
|
| 86 |
+
After choosing the pre-trained model and SeqIO Task/Mixture for your run, the
|
| 87 |
+
next step is to configure your run using Gin. If you're not familiar with Gin,
|
| 88 |
+
reading the [T5X Gin Primer](gin.md) is recommended.
|
| 89 |
+
|
| 90 |
+
T5X provides a Gin file that configures the T5X trainer for fine-tuning (located
|
| 91 |
+
at
|
| 92 |
+
[`t5x/configs/runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)),
|
| 93 |
+
and expects a few params from you. These params can be specified in a separate
|
| 94 |
+
Gin file, or via commandline flags. Following are the required params:
|
| 95 |
+
|
| 96 |
+
+ `INITIAL_CHECKPOINT_PATH`: This is the path to the pre-trained checkpoint
|
| 97 |
+
(from Step 1). For the example run, set this to
|
| 98 |
+
`'gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000'`.
|
| 99 |
+
+ `TRAIN_STEPS`: Number of fine-tuning steps. This includes the number of
|
| 100 |
+
steps that the model was pre-trained for, so make sure to add the step
|
| 101 |
+
number from the `INITIAL_CHECKPOINT_PATH`. For the example run, to fine-tune
|
| 102 |
+
for `20_000` steps, set this to `1_020_000`, since the initial checkpoint is
|
| 103 |
+
the `1_000_000`th step.
|
| 104 |
+
+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from
|
| 105 |
+
Step 2). For the example run, set this to `'wmt_t2t_ende_v003'`.
|
| 106 |
+
+ `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int
|
| 107 |
+
length for that feature. After preprocessing, features are truncated to the
|
| 108 |
+
provided value. For the example run, set this to `{'inputs': 256, 'targets':
|
| 109 |
+
256}`.
|
| 110 |
+
+ `MODEL_DIR`: A path to write fine-tuned checkpoints to. When launching using
|
| 111 |
+
XManager, this path is automatically set and can be accessed from the
|
| 112 |
+
XManager Artifacts page. When running locally using Blaze, you can
|
| 113 |
+
explicitly pass a directory using a flag. Launch commands are provided in
|
| 114 |
+
the next step.
|
| 115 |
+
+ `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
|
| 116 |
+
using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should
|
| 117 |
+
be set to `pretraining batch_size` * `pretrained target_token_length`. For
|
| 118 |
+
T5 and T5.1.1: `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
|
| 119 |
+
|
| 120 |
+
In addition to the above params, you will need to include
|
| 121 |
+
[`finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)
|
| 122 |
+
and the Gin file for the pre-trained model, which for the example run is
|
| 123 |
+
[`t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin).
|
| 124 |
+
|
| 125 |
+
```gin
|
| 126 |
+
include 't5x/configs/runs/finetune.gin'
|
| 127 |
+
include 't5x/examples/t5/t5_1_1/small.gin'
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
You will also need to import the Python module(s) that register SeqIO Tasks and
|
| 131 |
+
Mixtures used in your run. For the example run, we add `import t5.data.tasks`
|
| 132 |
+
since it is where `wmt_t2t_ende_v003` is registered.
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
Finally, your Gin file should look like this:
|
| 136 |
+
|
| 137 |
+
```gin
|
| 138 |
+
include 't5x/configs/runs/finetune.gin'
|
| 139 |
+
include 't5x/examples/t5/t5_1_1/small.gin'
|
| 140 |
+
|
| 141 |
+
# Register necessary SeqIO Tasks/Mixtures.
|
| 142 |
+
import t5.data.tasks
|
| 143 |
+
|
| 144 |
+
MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
|
| 145 |
+
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
|
| 146 |
+
TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps.
|
| 147 |
+
DROPOUT_RATE = 0.0
|
| 148 |
+
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
|
| 149 |
+
LOSS_NORMALIZING_FACTOR = 233472
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
See
|
| 153 |
+
[`t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
|
| 154 |
+
for this example.
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
## Step 4: Launch your experiment
|
| 158 |
+
|
| 159 |
+
To launch your experiment locally (for debugging only; larger checkpoints may
|
| 160 |
+
cause issues), run the following on commandline:
|
| 161 |
+
|
| 162 |
+
```sh
|
| 163 |
+
MODEL_DIR="/tmp/finetune-model/"
|
| 164 |
+
python -m t5x.train_unfragmented \
|
| 165 |
+
--gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
|
| 166 |
+
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
|
| 167 |
+
--alsologtostderr
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
Note that multiple comma-separated paths can be passed to the `gin_search_paths`
|
| 171 |
+
flag, and these paths should contain all Gin files used or included in your
|
| 172 |
+
experiment.
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
After fine-tuning has completed, you can parse metrics into CSV format using the
|
| 176 |
+
following script:
|
| 177 |
+
|
| 178 |
+
```sh
|
| 179 |
+
MODEL_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise
|
| 180 |
+
VAL_DIR="$MODEL_DIR/inference_eval"
|
| 181 |
+
python -m t5.scripts.parse_tb \
|
| 182 |
+
--summary_dir="$VAL_DIR" \
|
| 183 |
+
--seqio_summaries \
|
| 184 |
+
--out_file="$VAL_DIR/results.csv" \
|
| 185 |
+
--alsologtostderr
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### Metric Explanations
|
| 189 |
+
|
| 190 |
+
By default, t5x logs many metrics to TensorBoard, many of these seem similar but
|
| 191 |
+
have important distinctions.
|
| 192 |
+
|
| 193 |
+
The first two graphs you will see are the `accuracy` and `cross_ent_loss`
|
| 194 |
+
graphs. These are the *token-level teacher-forced* accuracy and cross entropy
|
| 195 |
+
loss respectively. Each of these graphs can have multiple curves on them. The
|
| 196 |
+
first curve is the `train` curve. This is calculated as a running sum than is
|
| 197 |
+
then normalized over the whole training set. The second class of curves have the
|
| 198 |
+
form `training_eval/${task_name}`. These curves are created by running a subset
|
| 199 |
+
(controlled by the `eval_steps` parameter of the main train function) of the
|
| 200 |
+
validation split of `${task_name}` through the model and calculating these
|
| 201 |
+
metrics using teacher-forcing. These graphs can commonly be used to find
|
| 202 |
+
"failure to learn" cases and as a warning sign of overfitting, but these are
|
| 203 |
+
often not the final metrics one would report on.
|
| 204 |
+
|
| 205 |
+
The second set of graphs are the ones under the collapsible `eval` section in
|
| 206 |
+
TensorBoard. These graphs are created based on the `metric_fns` defined in the
|
| 207 |
+
SeqIO task. The curves on these graphs have the form
|
| 208 |
+
`inference_eval/${task_name}`. Values are calculated by running the whole
|
| 209 |
+
validation split through the model in inference mode, commonly auto-regressive
|
| 210 |
+
decoding or output scoring. Most likely these are the metrics that will be
|
| 211 |
+
reported.
|
| 212 |
+
|
| 213 |
+
More information about the configuration of the datasets used for these
|
| 214 |
+
different metrics can be found [here](#train-train-eval-and-infer-eval).
|
| 215 |
+
|
| 216 |
+
In summary, the metric you actually care about most likely lives under the
|
| 217 |
+
`eval` tab rather, than in the `accuracy` graph.
|
| 218 |
+
|
| 219 |
+
## Next Steps
|
| 220 |
+
|
| 221 |
+
Now that you have successfully fine-tuned a pre-trained model on WMT, here are
|
| 222 |
+
some topics you might want to explore next:
|
| 223 |
+
|
| 224 |
+
+ [Evaluating a fine-tuned model.](eval.md)
|
| 225 |
+
+ [Running inference on a fine-tuned model.](infer.md)
|
| 226 |
+
+ [Training a model from scratch.](pretrain.md)
|
| 227 |
+
|
| 228 |
+
We also touch upon a few advanced topics related to fine-tuning below that might
|
| 229 |
+
be useful, especially when customizing your fine-tuning job.
|
| 230 |
+
|
| 231 |
+
## Advanced Topics
|
| 232 |
+
|
| 233 |
+
### `train`, `train_eval` and `infer_eval` {.no-toc}
|
| 234 |
+
|
| 235 |
+
A
|
| 236 |
+
[`DatasetConfig`](https://github.com/google-research/t5x/blob/main/t5x/utils.py?l=113&rcl=375475889)
|
| 237 |
+
object is used to configure loading SeqIO Tasks/Mixtures for training and eval.
|
| 238 |
+
If you take a closer look at
|
| 239 |
+
[`runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin),
|
| 240 |
+
you will see that there are three `DatasetConfig` objects defined and passed to
|
| 241 |
+
the train function: `train_dataset_cfg`, `train_eval_dataset_cfg`,
|
| 242 |
+
`infer_eval_dataset_cfg`. Here's a brief description of these configs:
|
| 243 |
+
|
| 244 |
+
+ `train`: This configures the Task/Mixture that the model will be fine-tuned
|
| 245 |
+
on.
|
| 246 |
+
+ `train_eval`: This configures the Task/Mixture that is used to compute
|
| 247 |
+
training metrics on the eval split, e.g. perplexity. These metrics are
|
| 248 |
+
defined in the
|
| 249 |
+
[`Model`](https://github.com/google-research/t5x/blob/main/t5x/models.py;l=257-267;rcl=394045248)
|
| 250 |
+
class and the eval fn is located
|
| 251 |
+
[here](https://github.com/google-research/t5x/blob/main/t5x/trainer.py;l=257;rcl=398487394).
|
| 252 |
+
+ `infer_eval`: This configures the Task/Mixture that is used to compute
|
| 253 |
+
metrics on inferred model outputs (e.g., comparing decoded model outputs and
|
| 254 |
+
targets). These metrics are defined in the SeqIO Task/Mixture and the eval
|
| 255 |
+
fn is located
|
| 256 |
+
[here](https://github.com/google/seqio/tree/main/seqio/evaluation.py?l=423&rcl=373643592)
|
| 257 |
+
|
| 258 |
+
### Using separate SeqIO Tasks/Mixtures for fine-tuning and eval {.no-toc}
|
| 259 |
+
|
| 260 |
+
Commonly, the same SeqIO Task/Mixture is used for training and eval. It is set
|
| 261 |
+
by the `MIXTURE_OR_TASK_NAME` macro in your fine-tune Gin file from Step 3
|
| 262 |
+
above, and is passed to `train_dataset_cfg`, `train_eval_dataset_cfg`,
|
| 263 |
+
`infer_eval_dataset_cfg`. The `train` split is used for training and the
|
| 264 |
+
`validation` split is used for evals. However, you can override these params in
|
| 265 |
+
your fine-tune Gin config. For example, if you want to fine-tune on all GLUE
|
| 266 |
+
tasks but evaluate only on GLUE STS benchmark, you can override the SeqIO
|
| 267 |
+
Task/Mixture used for `infer_eval` in your fine-tune Gin file as follows:
|
| 268 |
+
|
| 269 |
+
```gin
|
| 270 |
+
include 'runs/finetune.gin'
|
| 271 |
+
include 'models/t5_small.gin'
|
| 272 |
+
|
| 273 |
+
MIXTURE_OR_TASK_NAME = 'glue_v002_proportional'
|
| 274 |
+
MIXTURE_OR_TASK_MODULE = 't5.data.tasks'
|
| 275 |
+
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 84}
|
| 276 |
+
TRAIN_STEPS = 1_500_000 # includes 1_000_000 pretrain steps
|
| 277 |
+
INITIAL_CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000'
|
| 278 |
+
infer_eval/utils.DatasetConfig.mixture_or_task_name = 'glue_stsb_v002'
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
Other params in `finetune.gin` can be overridden in the same way.
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
### Defining a custom SeqIO Task/Mixture to fine-tune on {.no-toc}
|
| 285 |
+
|
| 286 |
+
Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
|
t5x-main/docs/usage/gin.md
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gin Primer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
[Gin](https://github.com/google/gin-config/blob/main/README.md) is a lightweight configuration framework for Python,
|
| 5 |
+
based on dependency injection. While T5X does not employ gin in its core
|
| 6 |
+
libraries, it is used to configure runs of the `train`, `eval`, and `infer`
|
| 7 |
+
scripts. This usage is a bit different (and more limited) than how gin is
|
| 8 |
+
typically applied, so this primer should be useful even for those who may be
|
| 9 |
+
familiar with gin from other libaries (e.g., T5 or Mesh TensorFlow).
|
| 10 |
+
|
| 11 |
+
Nevertheless, you may still find it helpful to refer to the
|
| 12 |
+
[gin documentation](https://github.com/google/gin-config/blob/main/README.md) for more background.
|
| 13 |
+
|
| 14 |
+
[TOC]
|
| 15 |
+
|
| 16 |
+
## Gin in T5X Scripts
|
| 17 |
+
|
| 18 |
+
Rather than plumbing run arguments and hyperparameters through via limited set
|
| 19 |
+
of command-line flags or a flat configuration schema, T5X's gin integration
|
| 20 |
+
allows you to parameterize the top-level run functions (`train`, `evaluate`, and
|
| 21 |
+
`infer`) as well as any object or function that is passed to them. This enables
|
| 22 |
+
a vast amount of flexibility over your runs without needing to modify any code
|
| 23 |
+
within the core T5X library.
|
| 24 |
+
|
| 25 |
+
For example, you can implement a Python class in your own codebase (e.g., a
|
| 26 |
+
custom model or trainer) and use gin to pass an instance of it to the T5X XM
|
| 27 |
+
launcher without having to fork any code. Previously you needed to implement
|
| 28 |
+
every experimental idea in the core library (no matter how widely used it would
|
| 29 |
+
be) and add a ConfigDict flag to enable/disable it, resulting in significant
|
| 30 |
+
code debt over time.
|
| 31 |
+
|
| 32 |
+
On the other hand, gin can sometimes be too powerful, allowing users the ability
|
| 33 |
+
to bind arguments throughout a codebase, which makes it difficult or impossible
|
| 34 |
+
to update "private" internal interfaces. However, by limiting configurability to
|
| 35 |
+
a single top-level function and its arguments we can better control the
|
| 36 |
+
configurable surface to public interfaces and user-owned code, and also avoid
|
| 37 |
+
unintended side effects.
|
| 38 |
+
|
| 39 |
+
### An Example
|
| 40 |
+
|
| 41 |
+
Let's look at the `evaluate` call signature from
|
| 42 |
+
[eval.py](https://github.com/google-research/t5x/blob/main/t5x/eval.py) as an example:
|
| 43 |
+
|
| 44 |
+
```py
|
| 45 |
+
def evaluate(*,
|
| 46 |
+
model: models.BaseModel,
|
| 47 |
+
dataset_cfg: utils.DatasetConfig,
|
| 48 |
+
restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
|
| 49 |
+
partitioner: partitioning.BasePartitioner,
|
| 50 |
+
output_dir: str):
|
| 51 |
+
"""Evaluation function.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
model: The model object to use for inference.
|
| 55 |
+
dataset_cfg: Specification for the dataset to infer based on.
|
| 56 |
+
restore_checkpoint_cfg: Specification for the model parameter checkpoint to
|
| 57 |
+
load.
|
| 58 |
+
partitioner: The partitioner for the model parameters and
|
| 59 |
+
data across devices.
|
| 60 |
+
output_dir: Path to directory to write temporary files and final results.
|
| 61 |
+
"""
|
| 62 |
+
...
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
In the binary, the user-provided gin configuration file will be parsed. It
|
| 66 |
+
specifies which values should be bound to the `evaluate` argument, after which
|
| 67 |
+
we can directly call the fully-bound function without any arguments. Basically,
|
| 68 |
+
we are creating a custom closure of `evaluate` (a la `functools.partial`) but
|
| 69 |
+
specifying the arguments via gin instead of Python.
|
| 70 |
+
|
| 71 |
+
Furthermore, this ability to bind custom arguments is recursive. Not only can we
|
| 72 |
+
bind the arguments of `evaluate`, but we can also bind the constructor and
|
| 73 |
+
method arguments of the instance of `models.BaseModel` that we pass to
|
| 74 |
+
`evaluate`.
|
| 75 |
+
|
| 76 |
+
Let's now look at an example of a gin configuration for parameterizing
|
| 77 |
+
`evaluate`, specifically evaluating a
|
| 78 |
+
[T5 model fine-tuned for closed book question answering](http://goo.gle/t5-cbqa)
|
| 79 |
+
on [Natural Questions Open](https://ai.google.com/research/NaturalQuestions):
|
| 80 |
+
|
| 81 |
+
```py
|
| 82 |
+
from __gin__ import dynamic_registration
|
| 83 |
+
|
| 84 |
+
import __main__ as eval_script
|
| 85 |
+
from t5x import models
|
| 86 |
+
from t5x import partitioning
|
| 87 |
+
from t5x import utils
|
| 88 |
+
|
| 89 |
+
MODEL = %gin.REQUIRED
|
| 90 |
+
|
| 91 |
+
eval_script.evaluate:
|
| 92 |
+
model = %MODEL
|
| 93 |
+
output_dir = '/tmp/t5x_eval'
|
| 94 |
+
dataset_cfg = @utils.DatasetConfig()
|
| 95 |
+
partitioner = @partitioning.PjitPartitioner()
|
| 96 |
+
restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
|
| 97 |
+
|
| 98 |
+
# Load model with overrides.
|
| 99 |
+
include 'models/t5_large.gin'
|
| 100 |
+
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
|
| 101 |
+
|
| 102 |
+
utils.DatasetConfig:
|
| 103 |
+
mixture_or_task_name = 'natural_questions_open'
|
| 104 |
+
split = 'test'
|
| 105 |
+
task_feature_lengths = None
|
| 106 |
+
batch_size = 32
|
| 107 |
+
shuffle = False
|
| 108 |
+
seed = 0
|
| 109 |
+
use_cached = False
|
| 110 |
+
pack = False
|
| 111 |
+
use_custom_packing_ops = False
|
| 112 |
+
module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
|
| 113 |
+
|
| 114 |
+
partitioning.PjitPartitioner:
|
| 115 |
+
num_partitions = 1
|
| 116 |
+
|
| 117 |
+
utils.RestoreCheckpointConfig:
|
| 118 |
+
mode = 'specific'
|
| 119 |
+
path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
|
| 120 |
+
assignment_map = None
|
| 121 |
+
strict = True
|
| 122 |
+
dtype = None
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Let's go through this block-by-block.
|
| 126 |
+
|
| 127 |
+
```py
|
| 128 |
+
from __gin__ import dynamic_registration
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
The first line imports a new gin feature (see cl/372624800 for more details) to
|
| 132 |
+
allow us to register functions and objects for configuration from within the gin
|
| 133 |
+
file itself without having to modify or decorate functions from the imported
|
| 134 |
+
packages.
|
| 135 |
+
|
| 136 |
+
```py
|
| 137 |
+
import __main__ as eval_script
|
| 138 |
+
from t5x import models
|
| 139 |
+
from t5x import utils
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
The second block imports the modules containing the components we plan to
|
| 143 |
+
configure in this file and is required for dynamic registration. Note that only
|
| 144 |
+
those functions and objects that we specify below will actually be configured,
|
| 145 |
+
not everything in the module. Also, as is the case in Python, the binary module
|
| 146 |
+
is referred as `__main__`, although we rename it to `eval_script` for clarity in
|
| 147 |
+
the rest of the config.
|
| 148 |
+
|
| 149 |
+
```py
|
| 150 |
+
MODEL = %gin.REQUIRED
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
The third block creates a
|
| 154 |
+
[gin macro](https://github.com/google/gin-config/tree/master/docs/index.md#gin-macros)
|
| 155 |
+
(essentially a lazy reference) and for now sets it to refer to the special macro
|
| 156 |
+
`gin.REQUIRED`, which will cause a failure during parsing of the configuration
|
| 157 |
+
if not updated via a later assignment in the config file or command-line flags
|
| 158 |
+
(see [below](#command-line-usage)).
|
| 159 |
+
|
| 160 |
+
```py
|
| 161 |
+
eval_script.evaluate:
|
| 162 |
+
model = %MODEL
|
| 163 |
+
output_dir = '/tmp/t5x_eval'
|
| 164 |
+
dataset_cfg = @utils.DatasetConfig()
|
| 165 |
+
partitioner = @partitioning.PjitPartitioner()
|
| 166 |
+
restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
The fourth block specifies the binding for the `evaluate` function. For `model`,
|
| 170 |
+
we pass the value of the `MODEL` macro (to be defined later). For `output_dir`
|
| 171 |
+
we pass a string path. For `dataset_cfg`, `restore_checkpoint_cfg`, and
|
| 172 |
+
`partitioner`, we pass instantiations of `DatasetConfig`,
|
| 173 |
+
`RestoreCheckpointConfig`, and `PjitPartitioner`, which are defined in
|
| 174 |
+
[utils.py](https://github.com/google-research/t5x/blob/main/t5x/utils.py) and
|
| 175 |
+
[partitioning.py](https://github.com/google-research/t5x/blob/main/t5x/partitioning.py)
|
| 176 |
+
respectively. The '@' prefix tells gin that the following is a configured
|
| 177 |
+
function or class, and the '()' suffix signifies that it should be called (in
|
| 178 |
+
the cases of class, this means calling the constructor). If we wanted to pass in
|
| 179 |
+
the closure (or a partially bound) function instead of its return value, we
|
| 180 |
+
would leave off the parentheses.
|
| 181 |
+
|
| 182 |
+
The remainder of the file deals with defining the `MODEL` macro and fully
|
| 183 |
+
binding these constructors.
|
| 184 |
+
|
| 185 |
+
```py
|
| 186 |
+
# Load model with overrides.
|
| 187 |
+
include 't5x/examples/t5/t5_1_1/large.gin'
|
| 188 |
+
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
Although we could define `MODEL = model.EncoderDecoderModel()` here, we prefer
|
| 192 |
+
to create a separate gin file that defines it. This makes it easier to reuse
|
| 193 |
+
parts of the common configurations. All of the bindings in the newly included
|
| 194 |
+
file are read and override any conflicting ones defined so far in this file.
|
| 195 |
+
It's equivalent to copy and pasting the contents of the included file at this
|
| 196 |
+
location in the config. If you want to see how the model itself is instantiated,
|
| 197 |
+
you can refer to
|
| 198 |
+
[t5_1_1/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin)
|
| 199 |
+
(which simply overrides a few values from
|
| 200 |
+
[t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin)).
|
| 201 |
+
|
| 202 |
+
The final line of this block shows an example of how you can modify the default
|
| 203 |
+
arguments of the `EncoderDecoderModel` instance referenced by `%MODEL`, in this
|
| 204 |
+
case changing the default beam size it will use during prediction. Notice that
|
| 205 |
+
since we are only binding one argument here, we choose to write it on a single
|
| 206 |
+
line instead of using the block binding syntax used elsewhere in the file.
|
| 207 |
+
|
| 208 |
+
```py
|
| 209 |
+
utils.DatasetConfig:
|
| 210 |
+
mixture_or_task_name = 'natural_questions_open'
|
| 211 |
+
split = 'test'
|
| 212 |
+
task_feature_lengths = None
|
| 213 |
+
batch_size = 32
|
| 214 |
+
shuffle = False
|
| 215 |
+
seed = 0
|
| 216 |
+
use_cached = False
|
| 217 |
+
pack = False
|
| 218 |
+
use_custom_packing_ops = False
|
| 219 |
+
module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
|
| 220 |
+
|
| 221 |
+
partitioning.PjitPartitioner:
|
| 222 |
+
num_partitions = 1
|
| 223 |
+
|
| 224 |
+
utils.RestoreCheckpointConfig:
|
| 225 |
+
mode = 'specific'
|
| 226 |
+
path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
|
| 227 |
+
assignment_map = None
|
| 228 |
+
strict = True
|
| 229 |
+
dtype = None
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
The last 3 blocks are fairly straightforward. They are effectively setting the
|
| 233 |
+
attributes of these dataclasses by binding values to their constructors that
|
| 234 |
+
will be used when they are instantiated and passed to `evaluate`, as specified
|
| 235 |
+
in the fourth block.
|
| 236 |
+
|
| 237 |
+
### Scoping
|
| 238 |
+
|
| 239 |
+
The above example lacks one key component of gin:
|
| 240 |
+
[scopes](https://github.com/google/gin-config/blob/main/README.md#4-configuring-the-same-function-in-different-ways-scopes).
|
| 241 |
+
|
| 242 |
+
What happens if you need to use a class or function multiple times but with
|
| 243 |
+
different bound values?
|
| 244 |
+
|
| 245 |
+
A clear example of this is in the top-level `train` function (in
|
| 246 |
+
[train.py](https://github.com/google-research/t5x/blob/main/t5x/train.py)). The call signature
|
| 247 |
+
includes 3 different instances of `utils.DatasetConfig`: one for the train
|
| 248 |
+
dataset, one for the "train-eval" dataset (used for evaluation with teacher
|
| 249 |
+
forcing), and one for the "infer-eval" dataset (used for evaluation with
|
| 250 |
+
inference/decoding).
|
| 251 |
+
|
| 252 |
+
The solution is to prefix each instance with a unique identifier both when
|
| 253 |
+
specifying where it is to be passed to `train` and when binding its arguments.
|
| 254 |
+
For example, the gin file might look like the following (skipping the irrelevant
|
| 255 |
+
bits):
|
| 256 |
+
|
| 257 |
+
```py
|
| 258 |
+
...
|
| 259 |
+
|
| 260 |
+
train_script.train:
|
| 261 |
+
train_dataset_cfg = @train/utils.DatasetConfig()
|
| 262 |
+
train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
|
| 263 |
+
infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
|
| 264 |
+
...
|
| 265 |
+
|
| 266 |
+
train/utils.DatasetConfig:
|
| 267 |
+
mixture_or_task_name = 'train_mixture'
|
| 268 |
+
split = 'train'
|
| 269 |
+
...
|
| 270 |
+
|
| 271 |
+
train_eval/utils.DatasetConfig:
|
| 272 |
+
mixture_or_task_name = 'eval_mixture'
|
| 273 |
+
split = 'validation'
|
| 274 |
+
...
|
| 275 |
+
|
| 276 |
+
infer_eval/utils.DatasetConfig:
|
| 277 |
+
mixture_or_task_name = 'eval_mixture'
|
| 278 |
+
split = 'test'
|
| 279 |
+
...
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
We have therefore configured 3 different scoped-versions of
|
| 283 |
+
`utils.DatasetConfig` producing 3 separate instances that are passed to `train`.
|
| 284 |
+
|
| 285 |
+
Note that these three scopes will all inherit from the base scope, so if you
|
| 286 |
+
want to set a shared binding, you may directly configure `utils.DatasetConfig`
|
| 287 |
+
without a scope prefix.
|
| 288 |
+
|
| 289 |
+
## Command-Line Usage
|
| 290 |
+
|
| 291 |
+
So now that you have a gin config, how do you pass it to the script? There are
|
| 292 |
+
two ways: gin files and override flags.
|
| 293 |
+
|
| 294 |
+
1. **Gin Files** You have already seen an example of a gin file above. You can
|
| 295 |
+
specify the gin file(s) to use in your script via the `--gin_file` flag. If
|
| 296 |
+
you want to load multiple gin files, you can set the flag multiple times and
|
| 297 |
+
the files will be loaded in order, with the second potentially overriding
|
| 298 |
+
the first when there are conflicts. It is possible to supply a
|
| 299 |
+
comma-separate list of search prefixes via `--gin_search_paths` and then
|
| 300 |
+
only specify the relative path to the `--gin_file` flags. However, we
|
| 301 |
+
strongly recommend against using `--gin_search_paths`. Using absolute paths
|
| 302 |
+
via the `--gin_file` flags will reduce sources of ambiguity and improve the
|
| 303 |
+
consistency of your scripts.
|
| 304 |
+
|
| 305 |
+
1. **Override Flags** Gin flags allow for more fine-grained overrides of any
|
| 306 |
+
configurable aspect of your run. These flags follow the single-line binding
|
| 307 |
+
format from the above example with the addition of a `--gin.` prefix. For
|
| 308 |
+
example, if you want to override the dataset shuffling, you can set
|
| 309 |
+
`--gin.utils.DatasetConfig.shuffle=False`. In the train setting where there
|
| 310 |
+
are multiple datasets, you must supply the appropriate scope, e.g.,
|
| 311 |
+
`--gin.train/utils.DatasetConfig.shuffle=False`. These bindings are
|
| 312 |
+
processed in order *after* the gin files are loaded, and therefore overwrite
|
| 313 |
+
any previously assigned value in the gin files.
|
| 314 |
+
|
| 315 |
+
**Note:** when supplying a string, dict, list, or tuple value via a flag, you
|
| 316 |
+
must put it in quotes. In the case of strings, it requires escaped quotes
|
| 317 |
+
(`\"<string>\"`). For example: `--gin.utils.DatasetConfig.split=\"validation\"`,
|
| 318 |
+
`--gin.utils.DatasetConfig.task_feature_lengths="{'inputs': 512, 'targets':
|
| 319 |
+
84}"`, and `--gin.dense.MlpBlock.activations="('dense', 'gelu')"`
|
| 320 |
+
|
| 321 |
+
### An Example
|
| 322 |
+
|
| 323 |
+
An example where you may need multiple files is with the `train` script.
|
| 324 |
+
|
| 325 |
+
You can first specify which model you want to train by supplying a gin file
|
| 326 |
+
containing its definition, for example:
|
| 327 |
+
[t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin).
|
| 328 |
+
|
| 329 |
+
You may then specify a run config that supplies some of the common defaults. For
|
| 330 |
+
example, if you are doing pretraining you can use
|
| 331 |
+
[runs/pretrain.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin),
|
| 332 |
+
and if you are doing finetuning, you can use
|
| 333 |
+
[runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin).
|
| 334 |
+
|
| 335 |
+
We can apply these two files with the following command:
|
| 336 |
+
|
| 337 |
+
```sh
|
| 338 |
+
python -m t5x.train_unfragmented \
|
| 339 |
+
--gin_file=t5x/examples/t5/t5_1_1/small.gin \
|
| 340 |
+
--gin_file=t5x/configs/runs/finetune.gin \
|
| 341 |
+
--logtostderr
|
| 342 |
+
```
|
| 343 |
+
|
| 344 |
+
However, running this command will give you an error like the following:
|
| 345 |
+
|
| 346 |
+
```sh
|
| 347 |
+
ValueError: MODEL_DIR/macro.value set to `%gin.REQUIRED` but not subsequently overridden.
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
This is because the config still includes some `gin.REQUIRED` macros that you'll
|
| 351 |
+
need to override with the details of your run. At the top of
|
| 352 |
+
[runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)
|
| 353 |
+
you'll see the list of required overrides, which we will populate for finetuning
|
| 354 |
+
on WMT in the updated launch command here:
|
| 355 |
+
|
| 356 |
+
```sh
|
| 357 |
+
python -m t5x.train_unfragmented \
|
| 358 |
+
--gin_file=t5x/examples/t5/t5_1_1/small.gin \
|
| 359 |
+
--gin_file=t5x/configs/runs/finetune.gin \
|
| 360 |
+
--gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
|
| 361 |
+
--gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
|
| 362 |
+
--gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" \
|
| 363 |
+
--gin.TRAIN_STEPS=1_020_000 \
|
| 364 |
+
--gin.MODEL_DIR=\"/tmp/t5_1_1_base_finetune_gin\" \
|
| 365 |
+
--gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000\" \
|
| 366 |
+
--logtostderr
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
Note you may still override any registered bindings. For example, to disable
|
| 370 |
+
inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`.
|
| 371 |
+
|
| 372 |
+
### A File-only Example
|
| 373 |
+
|
| 374 |
+
At the beginning of the primer, we saw a fully-specified run config. We can do
|
| 375 |
+
something similar with the previous example to create a self-contained run
|
| 376 |
+
configuration.
|
| 377 |
+
[t5_1_1/examples/small_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
|
| 378 |
+
is just such an example that allows you to exactly duplicate the previous launch
|
| 379 |
+
command simply by calling:
|
| 380 |
+
|
| 381 |
+
```sh
|
| 382 |
+
python -m t5x.train_unfragmented \
|
| 383 |
+
--gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
|
| 384 |
+
--gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \
|
| 385 |
+
--logtostderr
|
| 386 |
+
```
|
| 387 |
+
|
| 388 |
+
## Logging
|
| 389 |
+
|
| 390 |
+
After your gin files and flag overrides are parsed, the complete configuration
|
| 391 |
+
will be logged to INFO, written to `config.gin` in the output directory, and
|
| 392 |
+
added to a TensorBoard summary.
|
| 393 |
+
|
| 394 |
+
It is highly recommended that you review this generated config to ensure that
|
| 395 |
+
your overrides are working as expected.
|
t5x-main/docs/usage/gpu-usage.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GPU Scripts
|
| 2 |
+
|
| 3 |
+
# Warning!
|
| 4 |
+
An updated version of T5x with optimized GPU performance (18-80% perf gains!) and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x).
|
| 5 |
+
-----
|
| 6 |
+
**NVIDIA no longer recommends using this repository and won't be updating it further.**
|
| 7 |
+
-----
|
| 8 |
+
|
| 9 |
+
The [t5x/contrib/gpu](../../t5x/contrib/gpu) directory contains scripts optimized for GPU usage.
|
| 10 |
+
|
| 11 |
+
Install with `pip install -r pile_requirements.txt` to get all pile dependencies.
|
| 12 |
+
|
| 13 |
+
## Building the container
|
| 14 |
+
The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh <name>`
|
| 15 |
+
|
| 16 |
+
## Running interactively
|
| 17 |
+
Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example:
|
| 18 |
+
|
| 19 |
+
`t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir`
|
| 20 |
+
|
| 21 |
+
## Downloading The Pile
|
| 22 |
+
Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use.
|
| 23 |
+
|
| 24 |
+
## Single Node runs
|
| 25 |
+
Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host.
|
| 26 |
+
|
| 27 |
+
## Multi Node runs
|
| 28 |
+
For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput)
|
| 29 |
+
|
| 30 |
+
## Convergence
|
| 31 |
+
For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes.
|
| 32 |
+
|
| 33 |
+
| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log |
|
| 34 |
+
| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- |
|
| 35 |
+
| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) |
|
| 36 |
+
| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |
|
| 37 |
+
| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |
|
| 38 |
+
| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)|
|
| 39 |
+
|
| 40 |
+
Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any).
|
| 41 |
+
|
| 42 |
+
(More perf improvements coming soon!)
|
| 43 |
+
|
| 44 |
+
Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory.
|
| 45 |
+
|
| 46 |
+
## Pretraining run commands
|
| 47 |
+
|
| 48 |
+
### Singlenode
|
| 49 |
+
small:
|
| 50 |
+
|
| 51 |
+
`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}`
|
| 52 |
+
|
| 53 |
+
Finetuning:
|
| 54 |
+
MNLI v2:
|
| 55 |
+
`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}`
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
### Multinode
|
| 59 |
+
Arguments are as such:
|
| 60 |
+
|
| 61 |
+
`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
|
| 62 |
+
|
| 63 |
+
small:
|
| 64 |
+
|
| 65 |
+
`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1`
|
| 66 |
+
|
| 67 |
+
large:
|
| 68 |
+
|
| 69 |
+
`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1`
|
| 70 |
+
|
| 71 |
+
xl:
|
| 72 |
+
|
| 73 |
+
`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1`
|
| 74 |
+
|
| 75 |
+
Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from.
|
| 76 |
+
|
| 77 |
+
MNLI v2:
|
| 78 |
+
|
| 79 |
+
`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
|
| 80 |
+
|
| 81 |
+
SQuAD v1.1
|
| 82 |
+
|
| 83 |
+
`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
|
| 84 |
+
|
| 85 |
+
On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision.
|
| 86 |
+
|
| 87 |
+
WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up.
|
t5x-main/docs/usage/index.rst
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
T5X Usage Guides
|
| 2 |
+
================
|
| 3 |
+
|
| 4 |
+
.. toctree::
|
| 5 |
+
:maxdepth: 2
|
| 6 |
+
|
| 7 |
+
pretrain.md
|
| 8 |
+
finetune.md
|
| 9 |
+
eval.md
|
| 10 |
+
infer.md
|
| 11 |
+
auxiliary.md
|
| 12 |
+
decoding.md
|
| 13 |
+
metrics.md
|
| 14 |
+
partitioning.md
|
| 15 |
+
gin.md
|
| 16 |
+
|
t5x-main/docs/usage/infer-files.md
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running inference on a Model
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Introduction
|
| 5 |
+
|
| 6 |
+
This page outlines the steps to run inference a model with T5X on files
|
| 7 |
+
containing
|
| 8 |
+
[TensorFlow Examples](https://www.tensorflow.org/api_docs/python/tf/train/Example).
|
| 9 |
+
|
| 10 |
+
## Overview
|
| 11 |
+
|
| 12 |
+
Running inference on a model with T5X using TF Example files consists of the
|
| 13 |
+
following steps:
|
| 14 |
+
|
| 15 |
+
1. Choose the model to run inference on.
|
| 16 |
+
1. Choose the TF Example files to run inference on.
|
| 17 |
+
1. Write a Gin file that configures the model, file source and other details of
|
| 18 |
+
your inference run.
|
| 19 |
+
1. Launch your experiment locally or on XManager.
|
| 20 |
+
1. Monitor your experiment and access predictions.
|
| 21 |
+
|
| 22 |
+
These steps are explained in detail in the following sections. An example run
|
| 23 |
+
that runs inference on a fine-tuned T5-1.1-Small checkpoint on `tfrecord` files
|
| 24 |
+
containing the
|
| 25 |
+
[(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
|
| 26 |
+
is also showcased.
|
| 27 |
+
|
| 28 |
+
## Step 1: Choose a model
|
| 29 |
+
|
| 30 |
+
To run inference on a model, you need a Gin config file that defines the model
|
| 31 |
+
params, and the model checkpoint to load from. For this example, a T5-1.1-Small
|
| 32 |
+
model fine-tuned on the
|
| 33 |
+
[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
|
| 34 |
+
SeqIO Task will be used:
|
| 35 |
+
|
| 36 |
+
+ Model checkpoint -
|
| 37 |
+
[`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
|
| 38 |
+
+ Model Gin file -
|
| 39 |
+
[`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 40 |
+
|
| 41 |
+
If you would like to fine-tune your model before inference, please follow the
|
| 42 |
+
[fine-tuning](finetune.md) tutorial, and continue to Step 2.
|
| 43 |
+
|
| 44 |
+
## Step 2: Choose a TF Example file source
|
| 45 |
+
|
| 46 |
+
T5X supports running inference on `tfrecord`, `recordio` and `sstable` files
|
| 47 |
+
containing TF Examples. For the example run, you will run inference on
|
| 48 |
+
`tfrecord` files containing the `'natural_questions_open'` dataset located here:
|
| 49 |
+
`/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*`.
|
| 50 |
+
Here's an example of a single row of data from this file (you can explore this
|
| 51 |
+
file further using [GQUI](http://shortn/_oNuDhg7jwN)):
|
| 52 |
+
|
| 53 |
+
```json
|
| 54 |
+
{ # (tensorflow.Example) size=101B
|
| 55 |
+
features: { # (tensorflow.Features) size=99B
|
| 56 |
+
feature: { # (tensorflow.Features.FeatureEntry) size=27B
|
| 57 |
+
key: "answer" # size=6
|
| 58 |
+
value: { # (tensorflow.Feature) size=17B
|
| 59 |
+
bytes_list: { # (tensorflow.BytesList) size=15B
|
| 60 |
+
value: [ "Jason Flemyng" ] # size=13
|
| 61 |
+
} # features.feature[0].value.bytes_list
|
| 62 |
+
} # features.feature[0].value
|
| 63 |
+
} # features.feature[0]
|
| 64 |
+
feature: { # (tensorflow.Features.FeatureEntry) size=68B
|
| 65 |
+
key: "question" # size=8
|
| 66 |
+
value: { # (tensorflow.Feature) size=56B
|
| 67 |
+
bytes_list: { # (tensorflow.BytesList) size=54B
|
| 68 |
+
value: [ "who played hyde in league of extraordinary gentlemen" ] # size=52
|
| 69 |
+
} # features.feature[1].value.bytes_list
|
| 70 |
+
} # features.feature[1].value
|
| 71 |
+
} # features.feature[1]
|
| 72 |
+
} # features
|
| 73 |
+
}
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## Step 3: Write a Gin Config
|
| 77 |
+
|
| 78 |
+
After choosing the model and file source for your run, the next step is to
|
| 79 |
+
configure your run using Gin. If you're not familiar with Gin, reading the
|
| 80 |
+
[T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures
|
| 81 |
+
the T5X inference job (located at
|
| 82 |
+
[`t5x/configs/runs/infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin))
|
| 83 |
+
to run inference on TF Example files, and expects a few params from you. These
|
| 84 |
+
params can be specified in a separate Gin file, or via commandline flags.
|
| 85 |
+
Following are the required params:
|
| 86 |
+
|
| 87 |
+
+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
|
| 88 |
+
For the example run, set this to
|
| 89 |
+
`'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
|
| 90 |
+
+ `TF_EXAMPLE_FILE_PATHS`: This is a list of paths or glob patterns to read TF
|
| 91 |
+
Examples from. For the example run, set this to
|
| 92 |
+
`['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*']`.
|
| 93 |
+
+ `TF_EXAMPLE_FILE_TYPE`: This is the TF Example file format. Currently
|
| 94 |
+
supported file formats are `tfrecord`, `recordio` and `sstable`. For the
|
| 95 |
+
example run, set this to `'tfrecord'`.
|
| 96 |
+
+ `FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int length
|
| 97 |
+
for that feature. the TF Example features are truncated to the provided
|
| 98 |
+
value. For the example run, set this to `{'inputs': 38, 'targets': 18}`,
|
| 99 |
+
which is the maximum token length for the test set.
|
| 100 |
+
+ `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching
|
| 101 |
+
using XManager, this path is automatically set and can be accessed from the
|
| 102 |
+
XManager Artifacts page. When running locally using Blaze, you can
|
| 103 |
+
explicitly pass a directory using a flag. Launch commands are provided in
|
| 104 |
+
the next step.
|
| 105 |
+
|
| 106 |
+
In addition to the above params, you may also need to override the
|
| 107 |
+
`create_task_from_tfexample_file.inputs_key` param based on the data format (it
|
| 108 |
+
is set to `'inputs'` by default. For the example run, the `'question'` key
|
| 109 |
+
contains the input (see Step 2), so add the following to your Gin config:
|
| 110 |
+
|
| 111 |
+
```gin
|
| 112 |
+
create_task_from_tfexample_file.inputs_key = 'question'
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
Additionally, you will need to import the
|
| 116 |
+
[`infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin)
|
| 117 |
+
and the Gin file for the model, which for the example run is
|
| 118 |
+
[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 119 |
+
|
| 120 |
+
```gin
|
| 121 |
+
include 'runs/infer_from_tfexample_file.gin'
|
| 122 |
+
include 'models/t5_1_1_small.gin'
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Note that the `include` statements use relative paths in this example. You will
|
| 126 |
+
pass an appropriate `gin_search_paths` flag to locate these files when launching
|
| 127 |
+
your run. Absolute paths to Gin files can also be used, e.g.
|
| 128 |
+
|
| 129 |
+
```gin
|
| 130 |
+
include 't5x/configs/runs/infer_from_tfexample_file.gin'
|
| 131 |
+
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Finally, your Gin file should look like this:
|
| 135 |
+
|
| 136 |
+
```gin
|
| 137 |
+
include 'runs/infer_from_tfexample_file.gin'
|
| 138 |
+
include 'models/t5_1_1_small.gin'
|
| 139 |
+
|
| 140 |
+
CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
|
| 141 |
+
TF_EXAMPLE_FILE_PATHS = ['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*']
|
| 142 |
+
TF_EXAMPLE_FILE_TYPE = 'tfrecord'
|
| 143 |
+
FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}
|
| 144 |
+
create_task_from_tfexample_file.inputs_key = 'question'
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
See
|
| 148 |
+
[`t5x/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin)
|
| 149 |
+
for this example. Make sure that your Gin file is linked as a data dependency to
|
| 150 |
+
the T5X inference
|
| 151 |
+
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your
|
| 152 |
+
Gin file is not included, see the
|
| 153 |
+
[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for
|
| 154 |
+
instructions to add it, or skip writing a Gin file and pass the above params as
|
| 155 |
+
flags when launching the inference job (see instructions in Step 4).
|
| 156 |
+
|
| 157 |
+
## Step 4: Launch your experiment
|
| 158 |
+
|
| 159 |
+
To launch your experiment locally (for debugging only; larger checkpoints may
|
| 160 |
+
cause issues), run the following on commandline:
|
| 161 |
+
|
| 162 |
+
```sh
|
| 163 |
+
INFER_OUTPUT_DIR="/tmp/model-infer/"
|
| 164 |
+
python -m t5x.infer_unfragmented \
|
| 165 |
+
--gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin \
|
| 166 |
+
--gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
|
| 167 |
+
--alsologtostderr
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
Note that multiple comma-separated paths can be passed to the `gin_search_paths`
|
| 171 |
+
flag, and these paths should contain all Gin files used or included in your
|
| 172 |
+
experiment.
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
After inference has completed, you can view predictions in the `jsonl` files in
|
| 176 |
+
the output dir. JSON data is written in chunks and combined at the end of the
|
| 177 |
+
inference run. Refer to [Sharding](#sharding) and
|
| 178 |
+
[Checkpointing](#checkpointing) sections for more details.
|
| 179 |
+
|
| 180 |
+
## Next Steps
|
| 181 |
+
|
| 182 |
+
Now that you have successfully run inference on a model, here are some topics
|
| 183 |
+
you might want to explore next:
|
| 184 |
+
|
| 185 |
+
+ [Fine-tuning a model.](finetune.md)
|
| 186 |
+
+ [Evaluating a model.](eval.md)
|
| 187 |
+
+ [Training a model from scratch.](pretrain.md)
|
| 188 |
+
|
| 189 |
+
We also touch upon a few advanced topics related to inference below that might
|
| 190 |
+
be useful, especially when customizing your inference job.
|
| 191 |
+
|
| 192 |
+
## Advanced Topics
|
| 193 |
+
|
| 194 |
+
### Dataset Sharding {#sharding .no-toc}
|
| 195 |
+
|
| 196 |
+
You can run inference in parallel across multiple TPU slices by setting the
|
| 197 |
+
`num_shards` flag when running using XManager. When `num_shards > 1`, the
|
| 198 |
+
dataset is interleaved among the shards and the predictions are combined in the
|
| 199 |
+
end; hence the order of examples in the data source and the predictions in the
|
| 200 |
+
output json files will not match (order is guaranteed to match for `num_shards =
|
| 201 |
+
1` or the number of input file shards).
|
| 202 |
+
|
| 203 |
+
### Dataset Checkpointing {#checkpointing .no-toc}
|
| 204 |
+
|
| 205 |
+
You can control dataset checkpointing frequency by overriding the
|
| 206 |
+
`infer.checkpoint_period` in
|
| 207 |
+
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin),
|
| 208 |
+
which is set to `100` by default. This means that the dataset is checkpointed
|
| 209 |
+
after running inferences on `checkpoint_period` batches (batches, not examples;
|
| 210 |
+
you can control batch size by overriding `utils.DatasetConfig.batch_size` in
|
| 211 |
+
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), it
|
| 212 |
+
is set to `32` by default).
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc}
|
| 216 |
+
|
| 217 |
+
Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
|
t5x-main/docs/usage/infer-seqio.md
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running inference on a Model
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## Introduction
|
| 5 |
+
|
| 6 |
+
This page outlines the steps to run inference a model with T5X on Tasks/Mixtures
|
| 7 |
+
defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md).
|
| 8 |
+
|
| 9 |
+
## Overview
|
| 10 |
+
|
| 11 |
+
Running inference on a model with T5X using SeqIO Task/Mixtures consists of the
|
| 12 |
+
following steps:
|
| 13 |
+
|
| 14 |
+
1. Choose the model to run inference on.
|
| 15 |
+
1. Choose the SeqIO Task/Mixture to run inference on.
|
| 16 |
+
1. Write a Gin file that configures the model, SeqIO Task/Mixture and other
|
| 17 |
+
details of your inference run.
|
| 18 |
+
1. Launch your experiment locally or on XManager.
|
| 19 |
+
1. Monitor your experiment and access predictions.
|
| 20 |
+
|
| 21 |
+
These steps are explained in detail in the following sections. An example run
|
| 22 |
+
that runs inference on a fine-tuned T5-1.1-Small checkpoint on the
|
| 23 |
+
[(Open Domain) (Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/)
|
| 24 |
+
is also showcased.
|
| 25 |
+
|
| 26 |
+
## Step 1: Choose a model
|
| 27 |
+
|
| 28 |
+
To run inference on a model, you need a Gin config file that defines the model
|
| 29 |
+
params, and the model checkpoint to load from. For this example, a T5-1.1-Small
|
| 30 |
+
model fine-tuned on the
|
| 31 |
+
[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021)
|
| 32 |
+
SeqIO Task will be used:
|
| 33 |
+
|
| 34 |
+
+ Model checkpoint -
|
| 35 |
+
[`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/)
|
| 36 |
+
+ Model Gin file -
|
| 37 |
+
[`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 38 |
+
|
| 39 |
+
If you would like to fine-tune your model before inference, please follow the
|
| 40 |
+
[fine-tuning](finetune.md) tutorial, and continue to Step 2.
|
| 41 |
+
|
| 42 |
+
## Step 2: Choose a SeqIO Task/Mixture
|
| 43 |
+
|
| 44 |
+
A SeqIO Task encapsulates the data source, the preprocessing logic to be
|
| 45 |
+
performed on the data before querying the model, the postprocessing logic to be
|
| 46 |
+
performed on model outputs, and the metrics to be computed given the
|
| 47 |
+
postprocessed outputs and targets (for inference, post-processing and metrics
|
| 48 |
+
are irrelevant). A SeqIO Mixture denotes a collection of Tasks and enables
|
| 49 |
+
fine-tuning a model on multiple Tasks.
|
| 50 |
+
|
| 51 |
+
Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/),
|
| 52 |
+
[SuperGLUE](https://super.gluebenchmark.com/),
|
| 53 |
+
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
|
| 54 |
+
[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/),
|
| 55 |
+
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been
|
| 56 |
+
implemented as SeqIO Tasks/Mixtures and can be used directly. These
|
| 57 |
+
Tasks/Mixtures are defined in
|
| 58 |
+
[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
|
| 59 |
+
and
|
| 60 |
+
[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).
|
| 61 |
+
|
| 62 |
+
For the example run, you will run inference on the (Open Domain) Natural
|
| 63 |
+
Questions benchmark, which has been implemented as the `natural_questions_open`
|
| 64 |
+
Task in
|
| 65 |
+
[`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021).
|
| 66 |
+
Here's an example of a single row of preprocessed data from this Task:
|
| 67 |
+
|
| 68 |
+
```json
|
| 69 |
+
{
|
| 70 |
+
'inputs_pretokenized': 'nq question: what was the main motive of salt march',
|
| 71 |
+
'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1]
|
| 72 |
+
'targets_pretokenized': 'challenge to British authority',
|
| 73 |
+
'targets': [1921, 12, 2390, 5015, 1],
|
| 74 |
+
'answers': ['challenge to British authority']
|
| 75 |
+
}
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Step 3: Write a Gin Config
|
| 79 |
+
|
| 80 |
+
After choosing the model and SeqIO Task/Mixture for your run, the next step is
|
| 81 |
+
to configure your run using Gin. If you're not familiar with Gin, reading the
|
| 82 |
+
[T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures
|
| 83 |
+
the T5X inference job (located at
|
| 84 |
+
[`runs/infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin)) to
|
| 85 |
+
run inference on SeqIO Task/Mixtures, and expects a few params from you. These
|
| 86 |
+
params can be specified in a separate Gin file, or via commandline flags.
|
| 87 |
+
Following are the required params:
|
| 88 |
+
|
| 89 |
+
+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1).
|
| 90 |
+
For the example run, set this to
|
| 91 |
+
`'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`.
|
| 92 |
+
+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run
|
| 93 |
+
inference on (from Step 2). For the example run, set this to
|
| 94 |
+
`'natural_questions_open'`.
|
| 95 |
+
+ `MIXTURE_OR_TASK_MODULE`: This is the Python module that contains the SeqIO
|
| 96 |
+
Task or Mixture. For the example run, set this to
|
| 97 |
+
`'google_research.t5_closed_book_qa.t5_cbqa.tasks'`.
|
| 98 |
+
Note that this module must be included as a dependency in the T5X inference
|
| 99 |
+
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). Most
|
| 100 |
+
common Task modules, including `t5_closed_book_qa`, are already included. If
|
| 101 |
+
your module is not included, see the
|
| 102 |
+
[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial
|
| 103 |
+
for instructions to add it.
|
| 104 |
+
+ `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum length
|
| 105 |
+
for that feature. After preprocessing, features are truncated to the
|
| 106 |
+
provided value. For the example run, set this to `{'inputs': 38, 'targets':
|
| 107 |
+
18}`, which is the maximum token length for the test set.
|
| 108 |
+
+ `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching
|
| 109 |
+
using XManager, this path is automatically set and can be accessed from the
|
| 110 |
+
XManager Artifacts page. When running locally using Blaze, you can
|
| 111 |
+
explicitly pass a directory using a flag. Launch commands are provided in
|
| 112 |
+
the next step.
|
| 113 |
+
|
| 114 |
+
In addition to the above params, you will need to import
|
| 115 |
+
[`infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin) and the
|
| 116 |
+
Gin file for the model, which for the example run is
|
| 117 |
+
[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin).
|
| 118 |
+
|
| 119 |
+
```gin
|
| 120 |
+
include 'runs/infer.gin'
|
| 121 |
+
include 'models/t5_small.gin'
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
Note that the `include` statements use relative paths in this example. You will
|
| 125 |
+
pass an appropriate `gin_search_paths` flag to locate these files when launching
|
| 126 |
+
your run. Absolute paths to Gin files can also be used, e.g.
|
| 127 |
+
|
| 128 |
+
```gin
|
| 129 |
+
include 't5x/configs/runs/infer.gin'
|
| 130 |
+
include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin'
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
Finally, your Gin file should look like this:
|
| 134 |
+
|
| 135 |
+
```gin
|
| 136 |
+
include 'runs/infer.gin'
|
| 137 |
+
include 'models/t5_1_1_small.gin'
|
| 138 |
+
|
| 139 |
+
CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
|
| 140 |
+
MIXTURE_OR_TASK_NAME = 'closed_book_qa'
|
| 141 |
+
MIXTURE_OR_TASK_MODULE = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'
|
| 142 |
+
TASK_FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
See
|
| 146 |
+
[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin)
|
| 147 |
+
for this example. Make sure that your Gin file is linked as a data dependency to
|
| 148 |
+
the T5X inference
|
| 149 |
+
[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your
|
| 150 |
+
Gin file is not included, see the
|
| 151 |
+
[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for
|
| 152 |
+
instructions to add it, or skip writing a Gin file and pass the above params as
|
| 153 |
+
flags when launching the inference job (see instructions in Step 4).
|
| 154 |
+
|
| 155 |
+
## Step 4: Launch your experiment
|
| 156 |
+
|
| 157 |
+
To launch your experiment locally (for debugging only; larger checkpoints may
|
| 158 |
+
cause issues), run the following on commandline:
|
| 159 |
+
|
| 160 |
+
```sh
|
| 161 |
+
INFER_OUTPUT_DIR="/tmp/model-infer/"
|
| 162 |
+
python -m t5x.infer_unfragmented \
|
| 163 |
+
--gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin \
|
| 164 |
+
--gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
|
| 165 |
+
--alsologtostderr
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Note that multiple comma-separated paths can be passed to the `gin_search_paths`
|
| 169 |
+
flag, and these paths should contain all Gin files used or included in your
|
| 170 |
+
experiment.
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
## Step 5: Monitor your experiment and parse results
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
After inference has completed, you can view predictions in the `jsonl` files in
|
| 177 |
+
the output dir. JSON data is written in chunks and combined at the end of the
|
| 178 |
+
inference run. Refer to [Sharding](#sharding) and
|
| 179 |
+
[Checkpointing](#checkpointing) sections for more details.
|
| 180 |
+
|
| 181 |
+
## Next Steps
|
| 182 |
+
|
| 183 |
+
Now that you have successfully run inference on a model, here are some topics
|
| 184 |
+
you might want to explore next:
|
| 185 |
+
|
| 186 |
+
+ [Fine-tuning a model.](finetune)
|
| 187 |
+
+ [Evaluating a model.](eval)
|
| 188 |
+
+ [Training a model from scratch.](pretrain)
|
| 189 |
+
|
| 190 |
+
We also touch upon a few advanced topics related to inference below that might
|
| 191 |
+
be useful, especially when customizing your inference job.
|
| 192 |
+
|
| 193 |
+
## Advanced Topics
|
| 194 |
+
|
| 195 |
+
### Dataset Sharding {#sharding .no-toc}
|
| 196 |
+
|
| 197 |
+
You can run inference in parallel across multiple TPU slices by setting the
|
| 198 |
+
`num_shards` flag when running using XManager. When `num_shards > 1`, the
|
| 199 |
+
dataset is interleaved among the shards and the predictions are combined in the
|
| 200 |
+
end; hence the order of examples in the data source and the predictions in the
|
| 201 |
+
output json files will not match (order is guaranteed to match for `num_shards =
|
| 202 |
+
1` or the number of input file shards).
|
| 203 |
+
|
| 204 |
+
### Dataset Checkpointing {#checkpointing .no-toc}
|
| 205 |
+
|
| 206 |
+
You can control dataset checkpointing frequency by overriding the
|
| 207 |
+
`infer.checkpoint_period` in
|
| 208 |
+
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin),
|
| 209 |
+
which is set to `100` by default. This means that the dataset is checkpointed
|
| 210 |
+
after running inferences on `checkpoint_period` batches (batches, not examples;
|
| 211 |
+
you can control batch size by overriding `utils.DatasetConfig.batch_size` in
|
| 212 |
+
[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), it
|
| 213 |
+
is set to `32` by default).
|
| 214 |
+
|
| 215 |
+
### Changing Length and Decoding Strategy {#decoding-strategies .no-toc}
|
| 216 |
+
|
| 217 |
+
By default, T5X does inference using an arg-max decoding strategy, always
|
| 218 |
+
picking the most likely next token. To use random sampling instead, you may
|
| 219 |
+
change any of the following parameters in your gin config:
|
| 220 |
+
|
| 221 |
+
```gin
|
| 222 |
+
decoding.temperature_sample:
|
| 223 |
+
temperature = 1.0
|
| 224 |
+
topk = 1
|
| 225 |
+
topp = 0.0
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
You can also control the number of tokens which get generated by specifying:
|
| 229 |
+
|
| 230 |
+
```gin
|
| 231 |
+
decoding.temperature_sample:
|
| 232 |
+
max_decode_steps = 50
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
More detailed documentation on defining a decoding stategy can be found
|
| 236 |
+
[here](https://github.com/google-research/t5x/blob/main/docs/usage.md/decoding).
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc}
|
| 240 |
+
|
| 241 |
+
Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md).
|