Upload 549 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- CONTRIBUTING.md +32 -0
- IOU_test.py +21 -0
- LICENSE +202 -0
- README.md +217 -0
- __pycache__/owlv2_helper.cpython-310.pyc +0 -0
- __pycache__/owlv2_helper_functions.cpython-310.pyc +0 -0
- auto_bbox.py +266 -0
- big_vision/.gitignore +1 -0
- big_vision/CONTRIBUTING.md +26 -0
- big_vision/LICENSE +201 -0
- big_vision/README.md +499 -0
- big_vision/__init__.py +0 -0
- big_vision/__pycache__/__init__.cpython-310.pyc +0 -0
- big_vision/__pycache__/utils.cpython-310.pyc +0 -0
- big_vision/configs/__init__.py +0 -0
- big_vision/configs/bit_i1k.py +102 -0
- big_vision/configs/bit_i21k.py +85 -0
- big_vision/configs/common.py +188 -0
- big_vision/configs/common_fewshot.py +60 -0
- big_vision/configs/load_and_eval.py +143 -0
- big_vision/configs/mlp_mixer_i1k.py +120 -0
- big_vision/configs/transfer.py +186 -0
- big_vision/configs/vit_i1k.py +177 -0
- big_vision/configs/vit_i21k.py +145 -0
- big_vision/configs/vit_s16_i1k.py +105 -0
- big_vision/datasets/ai2d/ai2d.py +0 -0
- big_vision/datasets/aokvqa/aokvqa.py +0 -0
- big_vision/datasets/chartqa/chartqa.py +0 -0
- big_vision/datasets/coco35l/coco35l.py +0 -0
- big_vision/datasets/core.py +77 -0
- big_vision/datasets/countbenchqa/countbenchqa.py +0 -0
- big_vision/datasets/docvqa/docvqa.py +0 -0
- big_vision/datasets/gqa/gqa.py +0 -0
- big_vision/datasets/imagenet/class_names.py +0 -0
- big_vision/datasets/infovqa/infovqa.py +0 -0
- big_vision/datasets/jsonl.py +177 -0
- big_vision/datasets/nocaps/nocaps.py +0 -0
- big_vision/datasets/okvqa/okvqa.py +0 -0
- big_vision/datasets/pope/pope.py +0 -0
- big_vision/datasets/refcoco/refcoco.py +0 -0
- big_vision/datasets/rsvqa_hr/rsvqa_hr.py +0 -0
- big_vision/datasets/rsvqa_lr/rsvqa_lr.py +0 -0
- big_vision/datasets/scicap/scicap.py +0 -0
- big_vision/datasets/science_qa/science_qa.py +0 -0
- big_vision/datasets/screen2words/screen2words.py +0 -0
- big_vision/datasets/sequence_packing.py +77 -0
- big_vision/datasets/stvqa/stvqa.py +0 -0
- big_vision/datasets/tallyqa/tallyqa.py +0 -0
- big_vision/datasets/textcaps/textcaps.py +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ckpts/clip_vit_l14_with_masks_6c17944 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
images/scenic_design.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
images/scenic_logo.jpg filter=lfs diff=lfs merge=lfs -text
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to Contribute
|
| 2 |
+
|
| 3 |
+
Scenic is a platform used for developing new methods and ideas by Google
|
| 4 |
+
researchers, mostly around attention-based models for computer vision or
|
| 5 |
+
multi-modal applications. We encourage forking the repository and continued
|
| 6 |
+
development. We welcome suggestions and contributions to improving Scenic.
|
| 7 |
+
There are a few small guidelines you need to follow.
|
| 8 |
+
|
| 9 |
+
## Contributor License Agreement
|
| 10 |
+
|
| 11 |
+
Contributions to this project must be accompanied by a Contributor License
|
| 12 |
+
Agreement (CLA). You (or your employer) retain the copyright to your
|
| 13 |
+
contribution; this simply gives us permission to use and redistribute your
|
| 14 |
+
contributions as part of the project. Head over to
|
| 15 |
+
<https://cla.developers.google.com/> to see your current agreements on file or
|
| 16 |
+
to sign a new one.
|
| 17 |
+
|
| 18 |
+
You generally only need to submit a CLA once, so if you've already submitted one
|
| 19 |
+
(even if it was for a different project), you probably don't need to do it
|
| 20 |
+
again.
|
| 21 |
+
|
| 22 |
+
## Code Reviews
|
| 23 |
+
|
| 24 |
+
All submissions, including submissions by project members, require review. We
|
| 25 |
+
use GitHub pull requests for this purpose. Consult
|
| 26 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
| 27 |
+
information on using pull requests.
|
| 28 |
+
|
| 29 |
+
## Community Guidelines
|
| 30 |
+
|
| 31 |
+
This project follows
|
| 32 |
+
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
IOU_test.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from owlv2_helper_functions import get_iou, boxes_filter
|
| 2 |
+
|
| 3 |
+
boxes = [
|
| 4 |
+
(128.56, 4.57, 732.52, 476.05),
|
| 5 |
+
(569.65, 185.71, 740.31, 244.76),
|
| 6 |
+
(569.65, 185.71, 740.31, 244.76),
|
| 7 |
+
(569.65, 185.71, 740.31, 244.76),
|
| 8 |
+
(101.99, 99.00, 720.12, 88.63),
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
scores = [1.0, 0.99, 0.89, 1.0, 0.99]
|
| 12 |
+
|
| 13 |
+
instances = ['cat', 'dog', 'dog', 'tiger', 'cat']
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
pred_bboxes, pred_scores, instances = boxes_filter(boxes, scores, instances)
|
| 18 |
+
|
| 19 |
+
print(pred_bboxes)
|
| 20 |
+
print(pred_scores)
|
| 21 |
+
print(instances)
|
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scenic
|
| 2 |
+
<div style="text-align: left">
|
| 3 |
+
<img align="right" src="https://raw.githubusercontent.com/google-research/scenic/main/images/scenic_logo.png" width="200" alt="scenic logo"></img>
|
| 4 |
+
</div>
|
| 5 |
+
|
| 6 |
+
*Scenic* is a codebase with a focus on research around attention-based models
|
| 7 |
+
for computer vision. Scenic has been successfully used to develop
|
| 8 |
+
classification, segmentation, and detection models for multiple modalities
|
| 9 |
+
including images, video, audio, and multimodal combinations of them.
|
| 10 |
+
|
| 11 |
+
More precisely, *Scenic* is a (i) set of shared light-weight libraries solving
|
| 12 |
+
tasks commonly encountered tasks when training large-scale (i.e. multi-device,
|
| 13 |
+
multi-host) vision models; and (ii) several *projects* containing fully
|
| 14 |
+
fleshed out problem-specific training and evaluation loops using these
|
| 15 |
+
libraries.
|
| 16 |
+
|
| 17 |
+
Scenic is developed in [JAX](https://github.com/jax-ml/jax) and uses
|
| 18 |
+
[Flax](https://github.com/google/flax).
|
| 19 |
+
|
| 20 |
+
### Contents
|
| 21 |
+
* [What we offer](#what-we-offer)
|
| 22 |
+
* [SOTA models and baselines in Scenic](#sota-models-and-baselines-in-scenic)
|
| 23 |
+
* [Philosophy](#philosophy)
|
| 24 |
+
* [Getting started](#getting-started)
|
| 25 |
+
* [Scenic component design](#scenic-component-design)
|
| 26 |
+
* [Citing Scenic](#citing-scenic)
|
| 27 |
+
|
| 28 |
+
## What we offer
|
| 29 |
+
Among others *Scenic* provides
|
| 30 |
+
|
| 31 |
+
* Boilerplate code for launching experiments, summary writing, logging,
|
| 32 |
+
profiling, etc;
|
| 33 |
+
* Optimized training and evaluation loops, losses, metrics, bi-partite matchers,
|
| 34 |
+
etc;
|
| 35 |
+
* Input-pipelines for popular vision datasets;
|
| 36 |
+
* [Baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models),
|
| 37 |
+
including strong non-attentional baselines.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
## SOTA models and baselines in *Scenic*
|
| 41 |
+
There are some SOTA models and baselines in Scenic which were either developed
|
| 42 |
+
using Scenic, or have been reimplemented in Scenic:
|
| 43 |
+
|
| 44 |
+
Projects that were developed in Scenic or used it for their experiments:
|
| 45 |
+
|
| 46 |
+
* [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
|
| 47 |
+
* [OmniNet: Omnidirectional Representations from Transformers](https://arxiv.org/abs/2103.01075)
|
| 48 |
+
* [Attention Bottlenecks for Multimodal Fusion](https://arxiv.org/abs/2107.00135)
|
| 49 |
+
* [TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?](https://arxiv.org/abs/2106.11297)
|
| 50 |
+
* [Exploring the Limits of Large Scale Pre-training](https://arxiv.org/abs/2110.02095)
|
| 51 |
+
* [The Efficiency Misnomer](https://arxiv.org/abs/2110.12894)
|
| 52 |
+
* [Discrete Representations Strengthen Vision Transformer Robustness](https://arxiv.org/abs/2111.10493)
|
| 53 |
+
* [Pyramid Adversarial Training Improves ViT Performance](https://arxiv.org/abs/2111.15121)
|
| 54 |
+
* [VUT: Versatile UI Transformer for Multi-Modal Multi-Task User Interface Modeling](https://arxiv.org/abs/2112.05692)
|
| 55 |
+
* [CLAY: Learning to Denoise Raw Mobile UI Layouts for Improving Datasets at Scale](https://arxiv.org/abs/2201.04100)
|
| 56 |
+
* [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455)
|
| 57 |
+
* [Multiview Transformers for Video Recognition](https://arxiv.org/abs/2201.04288)
|
| 58 |
+
* [PolyViT: Co-training Vision Transformers on Images, Videos and Audio](https://arxiv.org/abs/2111.12993)
|
| 59 |
+
* [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230)
|
| 60 |
+
* [Learning with Neighbor Consistency for Noisy Labels](https://arxiv.org/abs/2202.02200)
|
| 61 |
+
* [Token Turing Machines](https://arxiv.org/pdf/2211.09119.pdf)
|
| 62 |
+
* [Vid2Seq: Large-Scale Pretraining of a Visual Language Model for Dense Video Captioning](https://arxiv.org/pdf/2302.14115.pdf)
|
| 63 |
+
* [AVATAR: Unconstrained Audiovisual Speech Recognition](https://arxiv.org/abs/2206.07684)
|
| 64 |
+
* [Adaptive Computation with Elastic Input Sequence](https://arxiv.org/abs/2301.13195)
|
| 65 |
+
* [Location-Aware Self-Supervised Transformers for Semantic Segmentation](https://arxiv.org/abs/2212.02400)
|
| 66 |
+
* [How can objects help action recognition?](https://openaccess.thecvf.com/content/CVPR2023/html/Zhou_How_Can_Objects_Help_Action_Recognition_CVPR_2023_paper.html)
|
| 67 |
+
* [Verbs in Action: Improving verb understanding in video-language models](https://arxiv.org/abs/2304.06708)
|
| 68 |
+
* [Unified Visual Relationship Detection with Vision and Language Models](https://arxiv.org/abs/2303.08998)
|
| 69 |
+
* [UnLoc: A Unified Framework for Video Localization Tasks](https://arxiv.org/abs/2308.11062)
|
| 70 |
+
* [REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory](https://arxiv.org/abs/2212.05221)
|
| 71 |
+
* [Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922)
|
| 72 |
+
* [MatFormer: Nested Transformer for Elastic Inference](https://arxiv.org/abs/2310.07707)
|
| 73 |
+
* [Pixel Aligned Language Models](https://arxiv.org/abs/2312.09237)
|
| 74 |
+
* [A Generative Approach for Wikipedia-Scale Visual Entity Recognition](https://arxiv.org/abs/2403.02041)
|
| 75 |
+
* [Streaming Dense Video Captioning](https://arxiv.org/abs/2404.01297)
|
| 76 |
+
* [Dense Video Object Captioning from Disjoint Supervision](https://arxiv.org/abs/2306.11729)
|
| 77 |
+
|
| 78 |
+
More information can be found in [projects](https://github.com/google-research/scenic/tree/main/scenic/projects#list-of-projects-hosted-in-scenic).
|
| 79 |
+
|
| 80 |
+
Baselines that were reproduced in Scenic:
|
| 81 |
+
|
| 82 |
+
* [(ViT) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
|
| 83 |
+
* [(DETR) End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872)
|
| 84 |
+
* [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159)
|
| 85 |
+
* [(CLIP) Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
|
| 86 |
+
* [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
|
| 87 |
+
* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
|
| 88 |
+
* [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
|
| 89 |
+
* [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370)
|
| 90 |
+
* [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
|
| 91 |
+
* [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
|
| 92 |
+
* [PCT: Point Cloud Transformer](https://arxiv.org/abs/2012.09688)
|
| 93 |
+
* [Universal Transformers](https://arxiv.org/abs/1807.03819)
|
| 94 |
+
* [PonderNet](https://arxiv.org/abs/2107.05407)
|
| 95 |
+
* [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
|
| 96 |
+
* [Rethinking Attention with Performers](https://arxiv.org/abs/2009.14794)
|
| 97 |
+
* [(CenterNet) Objects as Points](https://arxiv.org/abs/1904.07850)
|
| 98 |
+
* [(SAM) Segment Anything](https://arxiv.org/abs/2304.02643)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
More information can be found in [baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models).
|
| 102 |
+
|
| 103 |
+
<a name="philosophy"></a>
|
| 104 |
+
## Philosophy
|
| 105 |
+
*Scenic* aims to facilitate rapid prototyping of large-scale vision models. To
|
| 106 |
+
keep the code simple to understand and extend we prefer *forking and
|
| 107 |
+
copy-pasting over adding complexity or increasing abstraction*. Only when
|
| 108 |
+
functionality proves to be widely useful across many models and tasks it may be
|
| 109 |
+
upstreamed to Scenic's shared libraries.
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
<a name="getting_start"></a>
|
| 113 |
+
## Getting started
|
| 114 |
+
* See `projects/baselines/README.md` for a walk-through baseline models and
|
| 115 |
+
instructions on how to run the code.
|
| 116 |
+
* If you would like to contribute to *Scenic*, please check out the
|
| 117 |
+
[Philisophy](#philosophy), [Code structure](#code_structure) and
|
| 118 |
+
[Contributing](CONTRIBUTING.md) sections.
|
| 119 |
+
Should your contribution be a part of the shared libraries, please send us a
|
| 120 |
+
pull request!
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
### Quickstart
|
| 124 |
+
You will need Python 3.9 or later. Download the code from GitHub
|
| 125 |
+
|
| 126 |
+
```shell
|
| 127 |
+
$ git clone https://github.com/google-research/scenic.git
|
| 128 |
+
$ cd scenic
|
| 129 |
+
$ pip install .
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
and run training for ViT on ImageNet:
|
| 133 |
+
|
| 134 |
+
```shell
|
| 135 |
+
$ python scenic/main.py -- \
|
| 136 |
+
--config=scenic/projects/baselines/configs/imagenet/imagenet_vit_config.py \
|
| 137 |
+
--workdir=./
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
Note that for specific projects and baselines, you might need to install extra
|
| 141 |
+
packages that are mentioned in their `README.md` or `requirements.txt` files.
|
| 142 |
+
|
| 143 |
+
[Here](https://colab.research.google.com/github/google-research/scenic/blob/main/scenic/common_lib/colabs/scenic_playground.ipynb)
|
| 144 |
+
is also a minimal colab to train a simple feed-forward model using Scenic.
|
| 145 |
+
|
| 146 |
+
<a name="code_structure"></a>
|
| 147 |
+
## Scenic component design
|
| 148 |
+
Scenic is designed to propose different levels of abstraction, to support
|
| 149 |
+
hosting projects that only require changing hyper-parameters by defining config
|
| 150 |
+
files, to those that need customization on the input pipeline, model
|
| 151 |
+
architecture, losses and metrics, and the training loop. To make this happen,
|
| 152 |
+
the code in Scenic is organized as either _project-level_ code,
|
| 153 |
+
which refers to customized code for specific projects or baselines or
|
| 154 |
+
_library-level_ code, which refers to common functionalities and general
|
| 155 |
+
patterns that are adapted by the majority of projects. The project-level
|
| 156 |
+
code lives in the `projects` directory.
|
| 157 |
+
|
| 158 |
+
<div align="center">
|
| 159 |
+
<img src="https://raw.githubusercontent.com/google-research/scenic/main/images/scenic_design.jpg" width="900" alt="scenic design"></img>
|
| 160 |
+
</div>
|
| 161 |
+
|
| 162 |
+
### Library-level code
|
| 163 |
+
The goal is to keep the library-level code minimal and well-tested and to avoid
|
| 164 |
+
introducing extra abstractions to support minor use-cases. Shared libraries
|
| 165 |
+
provided by *Scenic* are split into:
|
| 166 |
+
|
| 167 |
+
* `dataset_lib`: Implements IO pipelines for loading and pre-processing data
|
| 168 |
+
for common Computer Vision tasks and benchmarks (see "Tasks and Datasets"
|
| 169 |
+
section). All pipelines are designed to be scalable and support multi-host
|
| 170 |
+
and multi-device setups, taking care dividing data among multiple hosts,
|
| 171 |
+
incomplete batches, caching, pre-fetching, etc.
|
| 172 |
+
* `model_lib` : Provides
|
| 173 |
+
* several abstract model interfaces (e.g. `ClassificationModel` or
|
| 174 |
+
`SegmentationModel` in `model_lib.base_models`) with task-specific
|
| 175 |
+
losses and metrics;
|
| 176 |
+
* neural network layers in `model_lib.layers`, focusing on efficient
|
| 177 |
+
implementation of attention and transformer layers;
|
| 178 |
+
* accelerator-friendly implementations of bipartite matching
|
| 179 |
+
algorithms in `model_lib.matchers`.
|
| 180 |
+
* `train_lib`: Provides tools for constructing training loops and implements
|
| 181 |
+
several optimized trainers (classification trainer and segmentation trainer)
|
| 182 |
+
that can be forked for customization.
|
| 183 |
+
* `common_lib`: General utilities, like logging and debugging modules,
|
| 184 |
+
functionalities for processing raw data, etc.
|
| 185 |
+
|
| 186 |
+
### Project-level code
|
| 187 |
+
Scenic supports the development of customized solutions for customized tasks and
|
| 188 |
+
data via the concept of "project". There is no one-fits-all recipe for how much
|
| 189 |
+
code should be re-used by a project. Projects can consist of only configs and
|
| 190 |
+
use the common models, trainers, task/data that live in library-level code, or
|
| 191 |
+
they can simply fork any of the mentioned functionalities and redefine, layers,
|
| 192 |
+
losses, metrics, logging methods, tasks, architectures, as well as training and
|
| 193 |
+
evaluation loops. The modularity of library-level code makes it flexible for
|
| 194 |
+
projects to fall placed on any spot in the "run-as-is" to "fully customized"
|
| 195 |
+
spectrum.
|
| 196 |
+
|
| 197 |
+
Common baselines such as a ResNet and Vision Transformer (ViT) are implemented
|
| 198 |
+
in the [`projects/baselines`](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines)
|
| 199 |
+
project. Forking models in this directory is a good starting point for new
|
| 200 |
+
projects.
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
## Citing Scenic
|
| 204 |
+
If you use Scenic, you can cite our [white paper](https://openaccess.thecvf.com/content/CVPR2022/html/Dehghani_Scenic_A_JAX_Library_for_Computer_Vision_Research_and_Beyond_CVPR_2022_paper.html).
|
| 205 |
+
Here is an example BibTeX entry:
|
| 206 |
+
|
| 207 |
+
```bibtex
|
| 208 |
+
@InProceedings{dehghani2021scenic,
|
| 209 |
+
author = {Dehghani, Mostafa and Gritsenko, Alexey and Arnab, Anurag and Minderer, Matthias and Tay, Yi},
|
| 210 |
+
title = {Scenic: A JAX Library for Computer Vision Research and Beyond},
|
| 211 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 212 |
+
year = {2022},
|
| 213 |
+
pages = {21393-21398}
|
| 214 |
+
}
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
_Disclaimer: This is not an official Google product._
|
__pycache__/owlv2_helper.cpython-310.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
__pycache__/owlv2_helper_functions.cpython-310.pyc
ADDED
|
Binary file (7.51 kB). View file
|
|
|
auto_bbox.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import cv2
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
import argparse
|
| 7 |
+
import subprocess
|
| 8 |
+
from typing import List, Tuple, Dict, Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ----------------- Args -----------------
|
| 15 |
+
def parse_args():
|
| 16 |
+
ap = argparse.ArgumentParser("OWLv2 detection on JPG folders (Top-K per image), multi-GPU.")
|
| 17 |
+
ap.add_argument("--input_dir", type=str, required=True, help="Root that contains subfolders of JPGs; if JPGs are directly under input_dir, it will be treated as a single set.")
|
| 18 |
+
ap.add_argument("--startswith", type=str, default="", help="Filter folder name prefix (or input_dir basename if no subfolders).")
|
| 19 |
+
ap.add_argument("--output_dir", type=str, required=True)
|
| 20 |
+
ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N-th image within a folder.")
|
| 21 |
+
ap.add_argument("--top_k", type=int, default=5)
|
| 22 |
+
ap.add_argument("--max_frames", type=int, default=0, help="Max processed images per folder; 0 means no limit.")
|
| 23 |
+
ap.add_argument("--num_workers", type=int, default=1, help="#GPUs/#workers")
|
| 24 |
+
ap.add_argument("--worker_idx", type=int, default=-1, help="internal; >=0 means child worker")
|
| 25 |
+
ap.add_argument("--shard_file", type=str, default="", help="internal; JSON with folder paths for this worker")
|
| 26 |
+
ap.add_argument("--scenic_root", type=str, default="/home/ubuntu/rs/JiT/VisionModels/Scenic_OWLv2/big_vision")
|
| 27 |
+
return ap.parse_args()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ----------------- Utils -----------------
|
| 31 |
+
def _has_jpgs(path: str) -> bool:
|
| 32 |
+
exts = ("*.jpg", "*.jpeg", "*.JPG", "*.JPEG")
|
| 33 |
+
for pat in exts:
|
| 34 |
+
if glob.glob(os.path.join(path, pat)):
|
| 35 |
+
return True
|
| 36 |
+
return False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def iter_image_dirs(input_dir: str, startswith: str) -> List[str]:
|
| 40 |
+
"""
|
| 41 |
+
Returns a list of directories to process.
|
| 42 |
+
- If input_dir contains subfolders: return subfolders that contain JPGs and match startswith.
|
| 43 |
+
- Else if input_dir itself contains JPGs and its basename matches startswith: return [input_dir].
|
| 44 |
+
"""
|
| 45 |
+
input_dir = os.path.abspath(input_dir)
|
| 46 |
+
subs = sorted([p for p in glob.glob(os.path.join(input_dir, "*")) if os.path.isdir(p)])
|
| 47 |
+
# Prefer subfolders if any exist and contain jpgs
|
| 48 |
+
dirs = [d for d in subs if os.path.basename(d).startswith(startswith) and _has_jpgs(d)]
|
| 49 |
+
if dirs:
|
| 50 |
+
return dirs
|
| 51 |
+
|
| 52 |
+
# Fallback: treat input_dir itself as one set if it has jpgs
|
| 53 |
+
base_ok = os.path.basename(os.path.normpath(input_dir)).startswith(startswith)
|
| 54 |
+
if base_ok and _has_jpgs(input_dir):
|
| 55 |
+
return [input_dir]
|
| 56 |
+
return []
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def ensure_dir(p: str):
|
| 60 |
+
os.makedirs(p, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def draw_single_box(frame_bgr: np.ndarray, box: List[float], color=(0, 255, 0), thickness=2) -> np.ndarray:
|
| 64 |
+
x1, y1, x2, y2 = map(int, box)
|
| 65 |
+
out = frame_bgr.copy()
|
| 66 |
+
cv2.rectangle(out, (x1, y1), (x2, y2), color, thickness)
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def list_images_sorted(folder: str) -> List[str]:
|
| 71 |
+
pats = ["*.jpg", "*.jpeg", "*.JPG", "*.JPEG"]
|
| 72 |
+
files = []
|
| 73 |
+
for pat in pats:
|
| 74 |
+
files.extend(glob.glob(os.path.join(folder, pat)))
|
| 75 |
+
# Sort by natural file name order
|
| 76 |
+
return sorted(files)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ----------------- Worker logic (imports JAX/Scenic inside) -----------------
|
| 80 |
+
def worker_run(args, dir_paths: List[str]):
|
| 81 |
+
import sys as _sys
|
| 82 |
+
if args.scenic_root not in _sys.path:
|
| 83 |
+
_sys.path.append(args.scenic_root)
|
| 84 |
+
|
| 85 |
+
# Free TF GPU to JAX in this process (why: avoid TF reserving VRAM)
|
| 86 |
+
import tensorflow as tf
|
| 87 |
+
tf.config.experimental.set_visible_devices([], "GPU")
|
| 88 |
+
|
| 89 |
+
from scenic.projects.owl_vit import configs
|
| 90 |
+
from scenic.projects.owl_vit import models
|
| 91 |
+
import jax
|
| 92 |
+
import functools
|
| 93 |
+
import owlv2_helper as helper # must be available in PYTHONPATH
|
| 94 |
+
|
| 95 |
+
class OWLv2Objectness:
|
| 96 |
+
def __init__(self, top_k: int = 5):
|
| 97 |
+
self.top_k = top_k
|
| 98 |
+
self.config = configs.owl_v2_clip_b16.get_config(init_mode="canonical_checkpoint")
|
| 99 |
+
self.module = models.TextZeroShotDetectionModule(
|
| 100 |
+
body_configs=self.config.model.body,
|
| 101 |
+
objectness_head_configs=self.config.model.objectness_head,
|
| 102 |
+
normalize=self.config.model.normalize,
|
| 103 |
+
box_bias=self.config.model.box_bias,
|
| 104 |
+
)
|
| 105 |
+
self.variables = self.module.load_variables(self.config.init_from.checkpoint_path)
|
| 106 |
+
|
| 107 |
+
self.image_embedder = jax.jit(
|
| 108 |
+
functools.partial(self.module.apply, self.variables, train=False, method=self.module.image_embedder)
|
| 109 |
+
)
|
| 110 |
+
self.objectness_predictor = jax.jit(
|
| 111 |
+
functools.partial(self.module.apply, self.variables, method=self.module.objectness_predictor)
|
| 112 |
+
)
|
| 113 |
+
self.box_predictor = jax.jit(
|
| 114 |
+
functools.partial(self.module.apply, self.variables, method=self.module.box_predictor)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def detect(self, image_bgr: np.ndarray) -> List[Tuple[List[float], float]]:
|
| 118 |
+
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
| 119 |
+
processed = helper.preprocess_images([image_rgb], self.config.dataset_configs.input_size)[0]
|
| 120 |
+
feature_map = self.image_embedder(processed[None, ...])
|
| 121 |
+
b, h, w, d = feature_map.shape
|
| 122 |
+
image_features = feature_map.reshape(b, h * w, d)
|
| 123 |
+
|
| 124 |
+
obj_logits = self.objectness_predictor(image_features)["objectness_logits"]
|
| 125 |
+
raw_boxes = self.box_predictor(image_features=image_features, feature_map=feature_map)["pred_boxes"]
|
| 126 |
+
|
| 127 |
+
obj = np.array(obj_logits[0], dtype=np.float32)
|
| 128 |
+
raw_boxes = np.array(raw_boxes[0], dtype=np.float32)
|
| 129 |
+
boxes = helper.rescale_detection_box(raw_boxes, image_rgb)
|
| 130 |
+
|
| 131 |
+
if len(obj) == 0:
|
| 132 |
+
return []
|
| 133 |
+
|
| 134 |
+
k = min(self.top_k, len(obj))
|
| 135 |
+
thresh = np.partition(obj, -k)[-k]
|
| 136 |
+
|
| 137 |
+
filtered: List[Tuple[List[float], float]] = []
|
| 138 |
+
H, W = image_rgb.shape[:2]
|
| 139 |
+
for box, score in zip(boxes, obj):
|
| 140 |
+
if score < thresh:
|
| 141 |
+
continue
|
| 142 |
+
if helper.too_small(box) or helper.too_large(box, image_rgb):
|
| 143 |
+
continue
|
| 144 |
+
x1, y1, x2, y2 = box
|
| 145 |
+
x1 = max(0, min(float(x1), W - 1))
|
| 146 |
+
y1 = max(0, min(float(y1), H - 1))
|
| 147 |
+
x2 = max(0, min(float(x2), W - 1))
|
| 148 |
+
y2 = max(0, min(float(y2), H - 1))
|
| 149 |
+
filtered.append(([x1, y1, x2, y2], float(score)))
|
| 150 |
+
|
| 151 |
+
kept_boxes = helper.remove_overlapping_bboxes([b for b, _ in filtered])
|
| 152 |
+
|
| 153 |
+
def _match_score(bb: List[float]) -> float:
|
| 154 |
+
arr = np.array([b for b, _ in filtered], dtype=np.float32)
|
| 155 |
+
idx = int(np.argmin(np.abs(arr - np.array(bb, dtype=np.float32)).sum(axis=1)))
|
| 156 |
+
return filtered[idx][1]
|
| 157 |
+
|
| 158 |
+
return [(bb, _match_score(bb)) for bb in kept_boxes]
|
| 159 |
+
|
| 160 |
+
detector = OWLv2Objectness(top_k=args.top_k)
|
| 161 |
+
|
| 162 |
+
for dpath in tqdm(dir_paths, desc=f"Worker{args.worker_idx}", unit="set"):
|
| 163 |
+
stem = os.path.basename(os.path.normpath(dpath))
|
| 164 |
+
images = list_images_sorted(dpath)
|
| 165 |
+
if not images:
|
| 166 |
+
print(f"[WARN][w{args.worker_idx}] No JPGs under: {dpath}")
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
saved_cnt = 0
|
| 170 |
+
pbar = tqdm(total=len(images), desc=f"{stem}[w{args.worker_idx}]", unit="img", leave=False)
|
| 171 |
+
|
| 172 |
+
for idx, ipath in enumerate(images):
|
| 173 |
+
pbar.update(1)
|
| 174 |
+
if args.frame_stride > 1 and (idx % args.frame_stride) != 0:
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
frame = cv2.imread(ipath, cv2.IMREAD_COLOR)
|
| 178 |
+
if frame is None:
|
| 179 |
+
print(f"[WARN][w{args.worker_idx}] Cannot read: {ipath}")
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
boxes_scores = detector.detect(frame)
|
| 183 |
+
if boxes_scores:
|
| 184 |
+
boxes_scores = sorted(boxes_scores, key=lambda x: x[1], reverse=True)[:args.top_k]
|
| 185 |
+
|
| 186 |
+
fname = os.path.basename(ipath)
|
| 187 |
+
for i, (box, score) in enumerate(boxes_scores):
|
| 188 |
+
out_dir = os.path.join(args.output_dir, stem, f"object_{i}")
|
| 189 |
+
ensure_dir(out_dir)
|
| 190 |
+
vis = draw_single_box(frame, box, color=(0, 255, 0), thickness=2)
|
| 191 |
+
cv2.imwrite(os.path.join(out_dir, fname), vis)
|
| 192 |
+
|
| 193 |
+
saved_cnt += 1
|
| 194 |
+
if args.max_frames and saved_cnt >= args.max_frames:
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
pbar.close()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ----------------- Master -----------------
|
| 201 |
+
def main():
|
| 202 |
+
args = parse_args()
|
| 203 |
+
|
| 204 |
+
# Child worker path
|
| 205 |
+
if args.worker_idx >= 0:
|
| 206 |
+
if not args.shard_file or not os.path.exists(args.shard_file):
|
| 207 |
+
raise RuntimeError("Worker requires --shard_file with JSON list of folder paths.")
|
| 208 |
+
with open(args.shard_file, "r", encoding="utf-8") as f:
|
| 209 |
+
dir_paths = json.load(f)
|
| 210 |
+
worker_run(args, dir_paths)
|
| 211 |
+
return
|
| 212 |
+
|
| 213 |
+
# Master path
|
| 214 |
+
dir_paths = iter_image_dirs(args.input_dir, args.startswith)
|
| 215 |
+
if not dir_paths:
|
| 216 |
+
print(f"[INFO] No JPG folders (or JPGs) startwith '{args.startswith}' under {args.input_dir}")
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
num_workers = max(1, int(args.num_workers))
|
| 220 |
+
shards: List[List[str]] = [[] for _ in range(num_workers)]
|
| 221 |
+
for i, d in enumerate(dir_paths):
|
| 222 |
+
shards[i % num_workers].append(d)
|
| 223 |
+
|
| 224 |
+
procs = []
|
| 225 |
+
tmp_dir = os.path.join(args.output_dir, "_shards_tmp")
|
| 226 |
+
ensure_dir(tmp_dir)
|
| 227 |
+
|
| 228 |
+
for w in range(num_workers):
|
| 229 |
+
shard_path = os.path.join(tmp_dir, f"shard_{w}.json")
|
| 230 |
+
with open(shard_path, "w", encoding="utf-8") as f:
|
| 231 |
+
json.dump(shards[w], f, ensure_ascii=False, indent=0)
|
| 232 |
+
|
| 233 |
+
# Bind GPU: cycle through available GPU ids [0..num_workers-1]
|
| 234 |
+
env = os.environ.copy()
|
| 235 |
+
env["CUDA_VISIBLE_DEVICES"] = str(w) # one GPU per worker
|
| 236 |
+
|
| 237 |
+
cmd = [
|
| 238 |
+
sys.executable, __file__,
|
| 239 |
+
"--input_dir", args.input_dir,
|
| 240 |
+
"--startswith", args.startswith,
|
| 241 |
+
"--output_dir", args.output_dir,
|
| 242 |
+
"--frame_stride", str(args.frame_stride),
|
| 243 |
+
"--top_k", str(args.top_k),
|
| 244 |
+
"--max_frames", str(args.max_frames),
|
| 245 |
+
"--num_workers", str(num_workers),
|
| 246 |
+
"--worker_idx", str(w),
|
| 247 |
+
"--shard_file", shard_path,
|
| 248 |
+
"--scenic_root", args.scenic_root,
|
| 249 |
+
]
|
| 250 |
+
print(f"[Master] Launch worker {w}: GPU={env['CUDA_VISIBLE_DEVICES']} folders={len(shards[w])}")
|
| 251 |
+
procs.append(subprocess.Popen(cmd, env=env))
|
| 252 |
+
|
| 253 |
+
# wait
|
| 254 |
+
rc = 0
|
| 255 |
+
for p in procs:
|
| 256 |
+
p.wait()
|
| 257 |
+
rc |= p.returncode
|
| 258 |
+
|
| 259 |
+
if rc != 0:
|
| 260 |
+
print("[Master] Some workers failed. Return code:", rc)
|
| 261 |
+
else:
|
| 262 |
+
print("[Master] All workers done. Output:", args.output_dir)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == "__main__":
|
| 266 |
+
main()
|
big_vision/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
big_vision/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to Contribute
|
| 2 |
+
|
| 3 |
+
At this time we do not plan to accept non-trivial contributions. The main
|
| 4 |
+
purpose of this codebase is to allow the community to reproduce results from our
|
| 5 |
+
publications.
|
| 6 |
+
|
| 7 |
+
You are however free to start a fork of the project for your purposes as
|
| 8 |
+
permitted by the license.
|
| 9 |
+
|
| 10 |
+
## Contributor License Agreement
|
| 11 |
+
|
| 12 |
+
Contributions to this project must be accompanied by a Contributor License
|
| 13 |
+
Agreement (CLA). You (or your employer) retain the copyright to your
|
| 14 |
+
contribution; this simply gives us permission to use and redistribute your
|
| 15 |
+
contributions as part of the project. Head over to
|
| 16 |
+
<https://cla.developers.google.com/> to see your current agreements on file or
|
| 17 |
+
to sign a new one.
|
| 18 |
+
|
| 19 |
+
You generally only need to submit a CLA once, so if you've already submitted one
|
| 20 |
+
(even if it was for a different project), you probably don't need to do it
|
| 21 |
+
again.
|
| 22 |
+
|
| 23 |
+
## Community Guidelines
|
| 24 |
+
|
| 25 |
+
This project follows
|
| 26 |
+
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
big_vision/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
big_vision/README.md
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Big Vision
|
| 2 |
+
|
| 3 |
+
This codebase is designed for training large-scale vision models using
|
| 4 |
+
[Cloud TPU VMs](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)
|
| 5 |
+
or GPU machines. It is based on [Jax](https://github.com/google/jax)/[Flax](https://github.com/google/flax)
|
| 6 |
+
libraries, and uses [tf.data](https://www.tensorflow.org/guide/data) and
|
| 7 |
+
[TensorFlow Datasets](https://www.tensorflow.org/datasets) for scalable and
|
| 8 |
+
reproducible input pipelines.
|
| 9 |
+
|
| 10 |
+
The open-sourcing of this codebase has two main purposes:
|
| 11 |
+
1. Publishing the code of research projects developed in this codebase (see a
|
| 12 |
+
list below).
|
| 13 |
+
2. Providing a strong starting point for running large-scale vision experiments
|
| 14 |
+
on GPU machines and Google Cloud TPUs, which should scale seamlessly and
|
| 15 |
+
out-of-the box from a single TPU core to a distributed setup with up to 2048
|
| 16 |
+
TPU cores.
|
| 17 |
+
|
| 18 |
+
`big_vision` aims to support research projects at Google. We are unlikely to
|
| 19 |
+
work on feature requests or accept external contributions, unless they were
|
| 20 |
+
pre-approved (ask in an issue first). For a well-supported transfer-only
|
| 21 |
+
codebase, see also [vision_transformer](https://github.com/google-research/vision_transformer).
|
| 22 |
+
|
| 23 |
+
Note that `big_vision` is quite dynamic codebase and, while we intend to keep
|
| 24 |
+
the core code fully-functional at all times, we can not guarantee timely updates
|
| 25 |
+
of the project-specific code that lives in the `.../proj/...` subfolders.
|
| 26 |
+
However, we provide a [table](#project-specific-commits) with last known
|
| 27 |
+
commits where specific projects were known to work.
|
| 28 |
+
|
| 29 |
+
The following research projects were originally conducted in the `big_vision`
|
| 30 |
+
codebase:
|
| 31 |
+
|
| 32 |
+
### Architecture research
|
| 33 |
+
|
| 34 |
+
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929), by
|
| 35 |
+
Alexey Dosovitskiy*, Lucas Beyer*, Alexander Kolesnikov*, Dirk Weissenborn*,
|
| 36 |
+
Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
|
| 37 |
+
Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby*
|
| 38 |
+
- [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560), by
|
| 39 |
+
Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*\
|
| 40 |
+
Resources: [config](big_vision/configs/proj/scaling_laws/train_vit_g.py).
|
| 41 |
+
- [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270), by
|
| 42 |
+
Andreas Steiner*, Alexander Kolesnikov*, Xiaohua Zhai*, Ross Wightman,
|
| 43 |
+
Jakob Uszkoreit, and Lucas Beyer*
|
| 44 |
+
- [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601), by
|
| 45 |
+
Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*,
|
| 46 |
+
Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner,
|
| 47 |
+
Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy\
|
| 48 |
+
Resources: [config](big_vision/configs/mlp_mixer_i1k.py).
|
| 49 |
+
- [Better plain ViT baselines for ImageNet-1k](https://arxiv.org/abs/2205.01580), by
|
| 50 |
+
Lucas Beyer, Xiaohua Zhai, Alexander Kolesnikov\
|
| 51 |
+
Resources: [config](big_vision/configs/vit_s16_i1k.py)
|
| 52 |
+
- [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by
|
| 53 |
+
Alexander Kolesnikov^*, André Susano Pinto^*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*\
|
| 54 |
+
Resources: [readme](big_vision/configs/proj/uvim/README.md), [configs](big_vision/configs/proj/uvim), [colabs](big_vision/configs/proj/uvim).
|
| 55 |
+
- [FlexiViT: One Model for All Patch Sizes](https://arxiv.org/abs/2212.08013), by
|
| 56 |
+
Lucas Beyer*, Pavel Izmailov*, Alexander Kolesnikov*, Mathilde Caron*, Simon
|
| 57 |
+
Kornblith*, Xiaohua Zhai*, Matthias Minderer*, Michael Tschannen*, Ibrahim
|
| 58 |
+
Alabdulmohsin*, Filip Pavetic*\
|
| 59 |
+
Resources: [readme](big_vision/configs/proj/flexivit/README.md), [configs](big_vision/configs/proj/flexivit).
|
| 60 |
+
- [Dual PatchNorm](https://arxiv.org/abs/2302.01327), by Manoj Kumar, Mostafa Dehghani, Neil Houlsby.
|
| 61 |
+
- [Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design](https://arxiv.org/abs/2305.13035), by
|
| 62 |
+
Ibrahim Alabdulmohsin*, Xiaohua Zhai*, Alexander Kolesnikov, Lucas Beyer*.
|
| 63 |
+
- (partial) [Scaling Vision Transformers to 22 Billion Parameters](https://arxiv.org/abs/2302.05442), by
|
| 64 |
+
Mostafa Dehghani*, Josip Djolonga*, Basil Mustafa*, Piotr Padlewski*, Jonathan Heek*, *wow many middle authors*, Neil Houlsby*.
|
| 65 |
+
- (partial) [Finite Scalar Quantization: VQ-VAE Made Simple](https://arxiv.org/abs/2309.15505), by
|
| 66 |
+
Fabian Mentzer, David Minnen, Eirikur Agustsson, Michael Tschannen.
|
| 67 |
+
- [GIVT: Generative Infinite-Vocabulary Transformers](https://arxiv.org/abs/2312.02116), by
|
| 68 |
+
Michael Tschannen, Cian Eastwood, Fabian Mentzer.\
|
| 69 |
+
Resources: [readme](big_vision/configs/proj/givt/README.md), [config](big_vision/configs/proj/givt/givt_imagenet2012.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/givt/givt_demo_colab.ipynb).
|
| 70 |
+
- [Unified Auto-Encoding with Masked Diffusion](https://arxiv.org/abs/2406.17688), by
|
| 71 |
+
Philippe Hansen-Estruch, Sriram Vishwanath, Amy Zhang, Manan Tomar.
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
### Multimodal research
|
| 75 |
+
|
| 76 |
+
- [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by
|
| 77 |
+
Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers,
|
| 78 |
+
Alexander Kolesnikov, and Lucas Beyer*\
|
| 79 |
+
Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).
|
| 80 |
+
- [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045), by
|
| 81 |
+
Michael Tschannen, Basil Mustafa, Neil Houlsby\
|
| 82 |
+
Resources: [readme](big_vision/configs/proj/clippo/README.md), [config](big_vision/configs/proj/clippo/train_clippo.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb).
|
| 83 |
+
- [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343), by
|
| 84 |
+
Xiaohua Zhai*, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer*\
|
| 85 |
+
Resources: [colab and models](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb), code TODO.
|
| 86 |
+
- [A Study of Autoregressive Decoders for Multi-Tasking in Computer Vision](https://arxiv.org/abs/2303.17376), by
|
| 87 |
+
Lucas Beyer*, Bo Wan*, Gagan Madan*, Filip Pavetic*, Andreas Steiner*, Alexander Kolesnikov, André Susano Pinto, Emanuele Bugliarello, Xiao Wang, Qihang Yu, Liang-Chieh Chen, Xiaohua Zhai*.
|
| 88 |
+
- [Image Captioners Are Scalable Vision Learners Too](https://arxiv.org/abs/2306.07915), by
|
| 89 |
+
Michael Tschannen*, Manoj Kumar*, Andreas Steiner*, Xiaohua Zhai, Neil Houlsby, Lucas Beyer*.\
|
| 90 |
+
Resources: [readme](big_vision/configs/proj/cappa/README.md), [config](big_vision/configs/proj/cappa/pretrain.py), [model](big_vision/models/proj/cappa/cappa.py).
|
| 91 |
+
- [Three Towers: Flexible Contrastive Learning with Pretrained Image Models](https://arxiv.org/abs/2305.16999), by Jannik Kossen, Mark Collier, Basil Mustafa, Xiao Wang, Xiaohua Zhai, Lucas Beyer, Andreas Steiner, Jesse Berent, Rodolphe Jenatton, Efi Kokiopoulou.
|
| 92 |
+
- (partial) [PaLI: A Jointly-Scaled Multilingual Language-Image Model](https://arxiv.org/abs/2209.06794), by Xi Chen, Xiao Wang, Soravit Changpinyo, *wow so many middle authors*, Anelia Angelova, Xiaohua Zhai, Neil Houlsby, Radu Soricut.
|
| 93 |
+
- (partial) [PaLI-3 Vision Language Models: Smaller, Faster, Stronger](https://arxiv.org/abs/2310.09199), by Xi Chen, Xiao Wang, Lucas Beyer, Alexander Kolesnikov, Jialin Wu, Paul Voigtlaender, Basil Mustafa, Sebastian Goodman, Ibrahim Alabdulmohsin, Piotr Padlewski, Daniel Salz, Xi Xiong, Daniel Vlasic, Filip Pavetic, Keran Rong, Tianli Yu, Daniel Keysers, Xiaohua Zhai, Radu Soricut.
|
| 94 |
+
- [LocCa](https://arxiv.org/abs/2403.19596), by
|
| 95 |
+
Bo Wan, Michael Tschannen, Yongqin Xian, Filip Pavetic, Ibrahim Alabdulmohsin, Xiao Wang, André Susano Pinto, Andreas Steiner, Lucas Beyer, Xiaohua Zhai.
|
| 96 |
+
- [PaliGemma](https://arxiv.org/abs/2407.07726),
|
| 97 |
+
[PaliGemma 2](https://arxiv.org/abs/2412.03555), by *wow many authors*.\
|
| 98 |
+
- Resources: [readme](big_vision/configs/proj/paligemma/README.md),
|
| 99 |
+
[model](big_vision/models/proj/paligemma/paligemma.py),
|
| 100 |
+
[transfer configs](big_vision/configs/proj/paligemma/transfers),
|
| 101 |
+
[datasets](big_vision/datasets),
|
| 102 |
+
[CountBenchQA](big_vision/datasets/countbenchqa/data/countbench_paired_questions.json).
|
| 103 |
+
|
| 104 |
+
### Training
|
| 105 |
+
|
| 106 |
+
- [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by
|
| 107 |
+
Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil,
|
| 108 |
+
and Alexander Kolesnikov*\
|
| 109 |
+
Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing).
|
| 110 |
+
- [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412), by
|
| 111 |
+
Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur
|
| 112 |
+
- [Surrogate Gap Minimization Improves Sharpness-Aware Training](https://arxiv.org/abs/2203.08065), by Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan and Ting Liu \
|
| 113 |
+
Resources: [trainer](big_vision/trainers/proj/gsam/gsam.py), [config](big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py) [reproduced results](https://github.com/google-research/big_vision/pull/8#pullrequestreview-1078557411)
|
| 114 |
+
- [Tuning computer vision models with task rewards](https://arxiv.org/abs/2302.08242), by
|
| 115 |
+
André Susano Pinto*, Alexander Kolesnikov*, Yuge Shi, Lucas Beyer, Xiaohua Zhai.
|
| 116 |
+
- (partial) [VeLO: Training Versatile Learned Optimizers by Scaling Up](https://arxiv.org/abs/2211.09760) by
|
| 117 |
+
Luke Metz, James Harrison, C. Daniel Freeman, Amil Merchant, Lucas Beyer, James Bradbury, Naman Agrawal, Ben Poole, Igor Mordatch, Adam Roberts, Jascha Sohl-Dickstein.
|
| 118 |
+
|
| 119 |
+
### Misc
|
| 120 |
+
|
| 121 |
+
- [Are we done with ImageNet?](https://arxiv.org/abs/2006.07159), by
|
| 122 |
+
Lucas Beyer*, Olivier J. Hénaff*, Alexander Kolesnikov*, Xiaohua Zhai*, Aäron van den Oord*.
|
| 123 |
+
- [No Filter: Cultural and Socioeconomic Diversity in Contrastive Vision-Language Models](https://arxiv.org/abs/2405.13777), by
|
| 124 |
+
Angéline Pouget, Lucas Beyer, Emanuele Bugliarello, Xiao Wang, Andreas Peter Steiner, Xiaohua Zhai, Ibrahim Alabdulmohsin.
|
| 125 |
+
|
| 126 |
+
# Codebase high-level organization and principles in a nutshell
|
| 127 |
+
|
| 128 |
+
The main entry point is a trainer module, which typically does all the
|
| 129 |
+
boilerplate related to creating a model and an optimizer, loading the data,
|
| 130 |
+
checkpointing and training/evaluating the model inside a loop. We provide the
|
| 131 |
+
canonical trainer `train.py` in the root folder. Normally, individual projects
|
| 132 |
+
within `big_vision` fork and customize this trainer.
|
| 133 |
+
|
| 134 |
+
All models, evaluators and preprocessing operations live in the corresponding
|
| 135 |
+
subdirectories and can often be reused between different projects. We encourage
|
| 136 |
+
compatible APIs within these directories to facilitate reusability, but it is
|
| 137 |
+
not strictly enforced, as individual projects may need to introduce their custom
|
| 138 |
+
APIs.
|
| 139 |
+
|
| 140 |
+
We have a powerful configuration system, with the configs living in the
|
| 141 |
+
`configs/` directory. Custom trainers and modules can directly extend/modify
|
| 142 |
+
the configuration options.
|
| 143 |
+
|
| 144 |
+
Project-specific code resides in the `.../proj/...` namespace. It is not always
|
| 145 |
+
possible to keep project-specific in sync with the core `big_vision` libraries,
|
| 146 |
+
Below we provide the [last known commit](#project-specific-commits)
|
| 147 |
+
for each project where the project code is expected to work.
|
| 148 |
+
|
| 149 |
+
Training jobs are robust to interruptions and will resume seamlessly from the
|
| 150 |
+
last saved checkpoint (assuming a user provides the correct `--workdir` path).
|
| 151 |
+
|
| 152 |
+
Each configuration file contains a comment at the top with a `COMMAND` snippet
|
| 153 |
+
to run it, and some hint of expected runtime and results. See below for more
|
| 154 |
+
details, but generally speaking, running on a GPU machine involves calling
|
| 155 |
+
`python -m COMMAND` while running on TPUs, including multi-host, involves
|
| 156 |
+
|
| 157 |
+
```
|
| 158 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all
|
| 159 |
+
--command "bash big_vision/run_tpu.sh COMMAND"
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
See instructions below for more details on how to run `big_vision` code on a
|
| 163 |
+
GPU machine or Google Cloud TPU.
|
| 164 |
+
|
| 165 |
+
By default we write checkpoints and logfiles. The logfiles are a list of JSON
|
| 166 |
+
objects, and we provide a short and straightforward [example colab to read
|
| 167 |
+
and display the logs and checkpoints](https://colab.research.google.com/drive/1R_lvV542WUp8Q2y8sbyooZOGCplkn7KI?usp=sharing).
|
| 168 |
+
|
| 169 |
+
# Current and future contents
|
| 170 |
+
|
| 171 |
+
The first release contains the core part of pre-training, transferring, and
|
| 172 |
+
evaluating classification models at scale on Cloud TPU VMs.
|
| 173 |
+
|
| 174 |
+
We have since added the following key features and projects:
|
| 175 |
+
- Contrastive Image-Text model training and evaluation as in LiT and CLIP.
|
| 176 |
+
- Patient and consistent distillation.
|
| 177 |
+
- Scaling ViT.
|
| 178 |
+
- MLP-Mixer.
|
| 179 |
+
- UViM.
|
| 180 |
+
|
| 181 |
+
Features and projects we plan to release in the near future, in no particular
|
| 182 |
+
order:
|
| 183 |
+
- ImageNet-21k in TFDS.
|
| 184 |
+
- Loading misc public models used in our publications (NFNet, MoCov3, DINO).
|
| 185 |
+
- Memory-efficient Polyak-averaging implementation.
|
| 186 |
+
- Advanced JAX compute and memory profiling. We are using internal tools for
|
| 187 |
+
this, but may eventually add support for the publicly available ones.
|
| 188 |
+
|
| 189 |
+
We will continue releasing code of our future publications developed within
|
| 190 |
+
`big_vision` here.
|
| 191 |
+
|
| 192 |
+
### Non-content
|
| 193 |
+
|
| 194 |
+
The following exist in the internal variant of this codebase, and there is no
|
| 195 |
+
plan for their release:
|
| 196 |
+
- Regular regression tests for both quality and speed. They rely heavily on
|
| 197 |
+
internal infrastructure.
|
| 198 |
+
- Advanced logging, monitoring, and plotting of experiments. This also relies
|
| 199 |
+
heavily on internal infrastructure. However, we are open to ideas on this
|
| 200 |
+
and may add some in the future, especially if implemented in a
|
| 201 |
+
self-contained manner.
|
| 202 |
+
- Not yet published, ongoing research projects.
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# GPU Setup
|
| 206 |
+
|
| 207 |
+
We first discuss how to setup and run `big_vision` on a (local) GPU machine,
|
| 208 |
+
and then discuss the setup for Cloud TPUs. Note that data preparation step for
|
| 209 |
+
(local) GPU setup can be largely reused for the Cloud TPU setup. While the
|
| 210 |
+
instructions skip this for brevity, we highly recommend using a
|
| 211 |
+
[virtual environment](https://docs.python.org/3/library/venv.html) when
|
| 212 |
+
installing python dependencies.
|
| 213 |
+
|
| 214 |
+
## Setting up python packages
|
| 215 |
+
|
| 216 |
+
The first step is to checkout `big_vision` and install relevant python
|
| 217 |
+
dependencies:
|
| 218 |
+
|
| 219 |
+
```
|
| 220 |
+
git clone https://github.com/google-research/big_vision
|
| 221 |
+
cd big_vision/
|
| 222 |
+
pip3 install --upgrade pip
|
| 223 |
+
pip3 install -r big_vision/requirements.txt
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
The latest version of `jax` library can be fetched as
|
| 227 |
+
|
| 228 |
+
```
|
| 229 |
+
pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
You may need a different `jax` package, depending on CUDA and cuDNN libraries
|
| 233 |
+
installed on your machine. Please consult
|
| 234 |
+
[official jax documentation](https://github.com/google/jax#pip-installation-gpu-cuda)
|
| 235 |
+
for more information.
|
| 236 |
+
|
| 237 |
+
## Preparing tfds data
|
| 238 |
+
|
| 239 |
+
For unified and reproducible access to standard datasets we opted to use the
|
| 240 |
+
`tensorflow_datasets` (`tfds`) library. It requires each dataset to be
|
| 241 |
+
downloaded, preprocessed and then to be stored on a hard drive (or, if you use
|
| 242 |
+
"Google Cloud", preferably stored in a "GCP bucket".).
|
| 243 |
+
|
| 244 |
+
Many datasets can be downloaded and preprocessed automatically when used
|
| 245 |
+
for the first time. Nevertheless, we intentionally disable this feature and
|
| 246 |
+
recommend doing dataset preparation step separately, ahead of the first run. It
|
| 247 |
+
will make debugging easier if problems arise and some datasets, like
|
| 248 |
+
`imagenet2012`, require manually downloaded data.
|
| 249 |
+
|
| 250 |
+
Most of the datasets, e.g. `cifar100`, `oxford_iiit_pet` or `imagenet_v2`
|
| 251 |
+
can be fully automatically downloaded and prepared by running
|
| 252 |
+
|
| 253 |
+
```
|
| 254 |
+
cd big_vision/
|
| 255 |
+
python3 -m big_vision.tools.download_tfds_datasets cifar100 oxford_iiit_pet imagenet_v2
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
A full list of datasets is available at [this link](https://www.tensorflow.org/datasets/catalog/overview#all_datasets).
|
| 259 |
+
|
| 260 |
+
Some datasets, like `imagenet2012` or `imagenet2012_real`, require the data to
|
| 261 |
+
be downloaded manually and placed into `$TFDS_DATA_DIR/downloads/manual/`,
|
| 262 |
+
which defaults to `~/tensorflow_datasets/downloads/manual/`. For example, for
|
| 263 |
+
`imagenet2012` and `imagenet2012_real` one needs to place the official
|
| 264 |
+
`ILSVRC2012_img_train.tar` and `ILSVRC2012_img_val.tar` files in that directory
|
| 265 |
+
and then run
|
| 266 |
+
`python3 -m big_vision.tools.download_tfds_datasets imagenet2012 imagenet2012_real`
|
| 267 |
+
(which may take ~1 hour).
|
| 268 |
+
|
| 269 |
+
If you use `Google Cloud` and, TPUs in particular, you can then upload
|
| 270 |
+
the preprocessed data (stored in `$TFDS_DATA_DIR`) to
|
| 271 |
+
"Google Cloud Bucket" and use the bucket on any of your (TPU) virtual
|
| 272 |
+
machines to access the data.
|
| 273 |
+
|
| 274 |
+
## Running on a GPU machine
|
| 275 |
+
|
| 276 |
+
Finally, after installing all python dependencies and preparing `tfds` data,
|
| 277 |
+
the user can run the job using config of their choice, e.g. to train `ViT-S/16`
|
| 278 |
+
model on ImageNet data, one should run the following command:
|
| 279 |
+
|
| 280 |
+
```
|
| 281 |
+
python3 -m big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir workdirs/`date '+%m-%d_%H%M'`
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
or to train MLP-Mixer-B/16, run (note the `gpu8` config param that reduces the default batch size and epoch count):
|
| 285 |
+
|
| 286 |
+
```
|
| 287 |
+
python3 -m big_vision.train --config big_vision/configs/mlp_mixer_i1k.py:gpu8 --workdir workdirs/`date '+%m-%d_%H%M'`
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
# Cloud TPU VM setup
|
| 291 |
+
|
| 292 |
+
## Create TPU VMs
|
| 293 |
+
|
| 294 |
+
To create a single machine with 8 TPU cores, follow the following Cloud TPU JAX
|
| 295 |
+
document:
|
| 296 |
+
https://cloud.google.com/tpu/docs/run-calculation-jax
|
| 297 |
+
|
| 298 |
+
To support large-scale vision research, more cores with multiple hosts are
|
| 299 |
+
recommended. Below we provide instructions on how to do it.
|
| 300 |
+
|
| 301 |
+
First, create some useful variables, which we be reused:
|
| 302 |
+
|
| 303 |
+
```
|
| 304 |
+
export NAME=<a name of the TPU deployment, e.g. my-tpu-machine>
|
| 305 |
+
export ZONE=<GCP geographical zone, e.g. europe-west4-a>
|
| 306 |
+
export GS_BUCKET_NAME=<Name of the storage bucket, e.g. my_bucket>
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
The following command line will create TPU VMs with 32 cores,
|
| 310 |
+
4 hosts.
|
| 311 |
+
|
| 312 |
+
```
|
| 313 |
+
gcloud compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version tpu-ubuntu2204-base
|
| 314 |
+
```
|
| 315 |
+
|
| 316 |
+
## Install `big_vision` on TPU VMs
|
| 317 |
+
|
| 318 |
+
Fetch the `big_vision` repository, copy it to all TPU VM hosts, and install
|
| 319 |
+
dependencies.
|
| 320 |
+
|
| 321 |
+
```
|
| 322 |
+
git clone https://github.com/google-research/big_vision
|
| 323 |
+
gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
|
| 324 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
## Download and prepare TFDS datasets
|
| 328 |
+
|
| 329 |
+
We recommend preparing `tfds` data locally as described above and then uploading
|
| 330 |
+
the data to `Google Cloud` bucket. However, if you prefer, the datasets which
|
| 331 |
+
do not require manual downloads can be prepared automatically using a TPU
|
| 332 |
+
machine as described below. Note that TPU machines have only 100 GB of disk
|
| 333 |
+
space, and multihost TPU slices do not allow for external disks to be attached
|
| 334 |
+
in a write mode, so the instructions below may not work for preparing large
|
| 335 |
+
datasets. As yet another alternative, we provide instructions
|
| 336 |
+
[on how to prepare `tfds` data on CPU-only GCP machine](#preparing-tfds-data-on-a-standalone-gcp-cpu-machine).
|
| 337 |
+
|
| 338 |
+
Specifically, the seven TFDS datasets used during evaluations will be generated
|
| 339 |
+
under `~/tensorflow_datasets` on TPU machine with this command:
|
| 340 |
+
|
| 341 |
+
```
|
| 342 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "TFDS_DATA_DIR=~/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
You can then copy the datasets to GS bucket, to make them accessible to all TPU workers.
|
| 346 |
+
|
| 347 |
+
```
|
| 348 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME"
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
If you want to integrate other public or custom datasets, i.e. imagenet2012,
|
| 352 |
+
please follow [the official guideline](https://www.tensorflow.org/datasets/catalog/overview).
|
| 353 |
+
|
| 354 |
+
## Pre-trained models
|
| 355 |
+
|
| 356 |
+
For the full list of pre-trained models check out the `load` function defined in
|
| 357 |
+
the same module as the model code. And for example config on how to use these
|
| 358 |
+
models, see `configs/transfer.py`.
|
| 359 |
+
|
| 360 |
+
## Run the transfer script on TPU VMs
|
| 361 |
+
|
| 362 |
+
The following command line fine-tunes a pre-trained `vit-i21k-augreg-b/32` model
|
| 363 |
+
on `cifar10` dataset.
|
| 364 |
+
|
| 365 |
+
```
|
| 366 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
## Run the train script on TPU VMs
|
| 370 |
+
|
| 371 |
+
To train your own big_vision models on a large dataset,
|
| 372 |
+
e.g. `imagenet2012` ([prepare the TFDS dataset](https://www.tensorflow.org/datasets/catalog/imagenet2012)),
|
| 373 |
+
run the following command line.
|
| 374 |
+
|
| 375 |
+
```
|
| 376 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
|
| 377 |
+
```
|
| 378 |
+
|
| 379 |
+
## FSDP training.
|
| 380 |
+
|
| 381 |
+
`big_vision` supports flexible parameter and model sharding strategies.
|
| 382 |
+
Currently, we support a popular FSDP sharding via a simple config change, see [this config example](big_vision/configs/transfer.py).
|
| 383 |
+
For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possible adjusting batch size depending on your hardware):
|
| 384 |
+
|
| 385 |
+
```
|
| 386 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-l/16,dataset=oxford_iiit_pet,crop=resmall_crop,fsdp=True,batch_size=256 --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
## Image-text training with SigLIP.
|
| 390 |
+
|
| 391 |
+
A minimal example that uses public `coco` captions data:
|
| 392 |
+
|
| 393 |
+
```
|
| 394 |
+
gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.siglip --config big_vision/configs/proj/image_text/siglip_lit_coco.py --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`"
|
| 395 |
+
```
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
## Sometimes useful gcloud commands
|
| 400 |
+
|
| 401 |
+
- Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE`
|
| 402 |
+
- Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
|
| 403 |
+
|
| 404 |
+
## Preparing `tfds` data on a standalone GCP CPU machine.
|
| 405 |
+
|
| 406 |
+
First create a new machine and a disk (feel free to adjust exact machine type and disk settings/capacity):
|
| 407 |
+
|
| 408 |
+
```
|
| 409 |
+
export NAME_CPU_HOST=<A name of a CPU-only machine>
|
| 410 |
+
export NAME_DISK=<A name of a disk>
|
| 411 |
+
gcloud compute instances create $NAME_CPU_HOST --machine-type c3-standard-22 --zone $ZONE --image-family ubuntu-2204-lts --image-project ubuntu-os-cloud
|
| 412 |
+
gcloud compute disks create $NAME_DISK --size 1000GB --zone $ZONE --type pd-balanced
|
| 413 |
+
```
|
| 414 |
+
|
| 415 |
+
Now attach the disk to the newly create machine:
|
| 416 |
+
|
| 417 |
+
```
|
| 418 |
+
gcloud compute instances attach-disk $NAME_CPU_HOST --disk $NAME_DISK --zone $ZONE
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
Next, `ssh` to the machine `gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE` and
|
| 422 |
+
[follow instructions to format and mount the disk](https://cloud.google.com/compute/docs/disks/format-mount-disk-linux).
|
| 423 |
+
Let's assume it was mounted to `/mnt/disks/tfds`.
|
| 424 |
+
|
| 425 |
+
Almost there, now clone and set up `big_vision`:
|
| 426 |
+
|
| 427 |
+
```
|
| 428 |
+
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "git clone https://github.com/google-research/big_vision.git && cd big_vision && sh big_vision/run_tpu.sh"
|
| 429 |
+
```
|
| 430 |
+
|
| 431 |
+
Finally, prepare the dataset (e.g. `coco_captions`) using the utility script and
|
| 432 |
+
copy the result to you google cloud bucket:
|
| 433 |
+
|
| 434 |
+
```
|
| 435 |
+
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "cd big_vision && TFDS_DATA_DIR=/mnt/disks/tfds/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets coco_captions"
|
| 436 |
+
gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "rm -rf /mnt/disks/tfds/tensorflow_datasets/downloads && gsutil cp -r /mnt/disks/tfds/tensorflow_datasets gs://$GS_BUCKET_NAME"
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# ViT baseline
|
| 441 |
+
|
| 442 |
+
We provide a well-tuned ViT-S/16 baseline in the config file named
|
| 443 |
+
`vit_s16_i1k.py`. It achieves 76.5% accuracy on ImageNet validation split in
|
| 444 |
+
90 epochs of training, being a strong and simple starting point for research
|
| 445 |
+
on the ViT models.
|
| 446 |
+
|
| 447 |
+
Please see our [arXiv note](https://arxiv.org/abs/2205.01580) for more details
|
| 448 |
+
and if this baseline happens to by useful for your research, consider citing
|
| 449 |
+
|
| 450 |
+
```
|
| 451 |
+
@article{vit_baseline,
|
| 452 |
+
url = {https://arxiv.org/abs/2205.01580},
|
| 453 |
+
author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
|
| 454 |
+
title = {Better plain ViT baselines for ImageNet-1k},
|
| 455 |
+
journal={arXiv preprint arXiv:2205.01580},
|
| 456 |
+
year = {2022},
|
| 457 |
+
}
|
| 458 |
+
```
|
| 459 |
+
|
| 460 |
+
# Project specific commits
|
| 461 |
+
|
| 462 |
+
The last known commit where the specific project code is expected to work. The
|
| 463 |
+
core code and configs are expected to work at head.
|
| 464 |
+
|
| 465 |
+
| Project | Commit |
|
| 466 |
+
|------------|-----------------------------------------------------------------------------------------------|
|
| 467 |
+
| UViM | https://github.com/google-research/big_vision/commit/21bd6ebe253f070f584d8b777ad76f4abce51bef |
|
| 468 |
+
| image_text | https://github.com/google-research/big_vision/commit/8921d5141504390a8a4f7b2dacb3b3c042237290 |
|
| 469 |
+
| distill | https://github.com/google-research/big_vision/commit/2f3f493af048dbfd97555ff6060f31a0e686f17f |
|
| 470 |
+
| GSAM | WIP |
|
| 471 |
+
| CLIPPO | https://github.com/google-research/big_vision/commit/fd2d3bd2efc9d89ea959f16cd2f58ae8a495cd44 |
|
| 472 |
+
| CapPa | https://github.com/google-research/big_vision/commit/7ace659452dee4b68547575352c022a2eef587a5 |
|
| 473 |
+
| GIVT | https://github.com/google-research/big_vision/commit/0cb70881dd33b3343b769347dc19793c4994b8cb |
|
| 474 |
+
|
| 475 |
+
# Citing the codebase
|
| 476 |
+
|
| 477 |
+
If you found this codebase useful for your research, please consider using
|
| 478 |
+
the following BibTEX to cite it:
|
| 479 |
+
|
| 480 |
+
```
|
| 481 |
+
@misc{big_vision,
|
| 482 |
+
author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
|
| 483 |
+
title = {Big Vision},
|
| 484 |
+
year = {2022},
|
| 485 |
+
publisher = {GitHub},
|
| 486 |
+
journal = {GitHub repository},
|
| 487 |
+
howpublished = {\url{https://github.com/google-research/big_vision}}
|
| 488 |
+
}
|
| 489 |
+
```
|
| 490 |
+
|
| 491 |
+
# Disclaimer
|
| 492 |
+
|
| 493 |
+
This is not an official Google Product.
|
| 494 |
+
|
| 495 |
+
# License
|
| 496 |
+
|
| 497 |
+
Unless explicitly noted otherwise, everything in the big_vision codebase
|
| 498 |
+
(including models and colabs) is released under the Apache2 license.
|
| 499 |
+
See the LICENSE file for the full license text.
|
big_vision/__init__.py
ADDED
|
File without changes
|
big_vision/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (179 Bytes). View file
|
|
|
big_vision/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (52.6 kB). View file
|
|
|
big_vision/configs/__init__.py
ADDED
|
File without changes
|
big_vision/configs/bit_i1k.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370
|
| 17 |
+
|
| 18 |
+
Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128:
|
| 19 |
+
|
| 20 |
+
big_vision.train \
|
| 21 |
+
--config big_vision/configs/bit_i1k.py \
|
| 22 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
| 23 |
+
--config.model.depth 50 --config.model.width 1
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 27 |
+
import ml_collections as mlc
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_config(runlocal=False):
|
| 31 |
+
"""Config for training on ImageNet-1k."""
|
| 32 |
+
config = mlc.ConfigDict()
|
| 33 |
+
|
| 34 |
+
config.seed = 0
|
| 35 |
+
config.total_epochs = 90
|
| 36 |
+
config.num_classes = 1000
|
| 37 |
+
config.loss = 'softmax_xent'
|
| 38 |
+
|
| 39 |
+
config.input = dict()
|
| 40 |
+
config.input.data = dict(
|
| 41 |
+
name='imagenet2012',
|
| 42 |
+
split='train[:99%]',
|
| 43 |
+
)
|
| 44 |
+
config.input.batch_size = 4096
|
| 45 |
+
config.input.cache_raw = True # Needs up to 120GB of RAM!
|
| 46 |
+
config.input.shuffle_buffer_size = 250_000 # Per host.
|
| 47 |
+
|
| 48 |
+
pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
|
| 49 |
+
pp_common += '|value_range(-1, 1)|keep("image", "labels")'
|
| 50 |
+
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
|
| 51 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
| 52 |
+
|
| 53 |
+
config.log_training_steps = 50
|
| 54 |
+
config.ckpt_steps = 1000
|
| 55 |
+
|
| 56 |
+
# Model section
|
| 57 |
+
config.model_name = 'bit'
|
| 58 |
+
config.model = dict(
|
| 59 |
+
depth=50, # You can also pass e.g. [3, 5, 10, 2]
|
| 60 |
+
width=1.0,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Optimizer section
|
| 64 |
+
config.optax_name = 'big_vision.momentum_hp'
|
| 65 |
+
config.grad_clip_norm = 1.0
|
| 66 |
+
|
| 67 |
+
# linear scaling rule. Don't forget to sweep if sweeping batch_size.
|
| 68 |
+
config.wd = (1e-4 / 256) * config.input.batch_size
|
| 69 |
+
config.lr = (0.1 / 256) * config.input.batch_size
|
| 70 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=1000)
|
| 71 |
+
|
| 72 |
+
# Eval section
|
| 73 |
+
def get_eval(split, dataset='imagenet2012'):
|
| 74 |
+
return dict(
|
| 75 |
+
type='classification',
|
| 76 |
+
data=dict(name=dataset, split=split),
|
| 77 |
+
pp_fn=pp_eval.format(lbl='label'),
|
| 78 |
+
loss_name=config.loss,
|
| 79 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
| 80 |
+
cache='final_data',
|
| 81 |
+
)
|
| 82 |
+
config.evals = {}
|
| 83 |
+
config.evals.train = get_eval('train[:2%]')
|
| 84 |
+
config.evals.minival = get_eval('train[99%:]')
|
| 85 |
+
config.evals.val = get_eval('validation')
|
| 86 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
| 87 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
| 88 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
| 89 |
+
|
| 90 |
+
# config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
|
| 91 |
+
# config.evals.fewshot.log_steps = 1000
|
| 92 |
+
|
| 93 |
+
if runlocal:
|
| 94 |
+
config.input.batch_size = 32
|
| 95 |
+
config.input.cache_raw = False
|
| 96 |
+
config.input.shuffle_buffer_size = 100
|
| 97 |
+
|
| 98 |
+
local_eval = config.evals.val
|
| 99 |
+
config.evals = {'val': local_eval}
|
| 100 |
+
config.evals.val.cache = 'none'
|
| 101 |
+
|
| 102 |
+
return config
|
big_vision/configs/bit_i21k.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""A config for pre-training BiT on ImageNet-21k.
|
| 17 |
+
|
| 18 |
+
This config relies on the Imagenet-21k tfds dataset, which is not yet
|
| 19 |
+
available publicly in TFDS. We intend to add the dataset to public TFDS soon,
|
| 20 |
+
and this config will then be runnable.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 24 |
+
import ml_collections as mlc
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_config():
|
| 28 |
+
"""Config for training on imagenet-21k."""
|
| 29 |
+
config = mlc.ConfigDict()
|
| 30 |
+
|
| 31 |
+
config.seed = 0
|
| 32 |
+
config.total_epochs = 90
|
| 33 |
+
config.num_classes = 21843
|
| 34 |
+
config.init_head_bias = -10.0
|
| 35 |
+
config.loss = 'sigmoid_xent'
|
| 36 |
+
|
| 37 |
+
config.input = dict()
|
| 38 |
+
config.input.data = dict(
|
| 39 |
+
name='imagenet21k',
|
| 40 |
+
split='full[51200:]',
|
| 41 |
+
)
|
| 42 |
+
config.input.batch_size = 4096
|
| 43 |
+
config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
|
| 44 |
+
|
| 45 |
+
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
|
| 46 |
+
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
|
| 47 |
+
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
|
| 48 |
+
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
|
| 49 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)'
|
| 50 |
+
|
| 51 |
+
config.log_training_steps = 50
|
| 52 |
+
config.ckpt_steps = 1000
|
| 53 |
+
|
| 54 |
+
# Model section
|
| 55 |
+
config.model_name = 'bit_paper'
|
| 56 |
+
config.model = dict(depth=50, width=1.0)
|
| 57 |
+
|
| 58 |
+
# Optimizer section
|
| 59 |
+
config.optax_name = 'big_vision.momentum_hp'
|
| 60 |
+
config.grad_clip_norm = 1.0
|
| 61 |
+
|
| 62 |
+
# linear scaling rule. Don't forget to sweep if sweeping batch_size.
|
| 63 |
+
config.lr = (0.03 / 256) * config.input.batch_size
|
| 64 |
+
config.wd = (3e-5 / 256) * config.input.batch_size
|
| 65 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=5000)
|
| 66 |
+
|
| 67 |
+
# Evaluations on i21k itself.
|
| 68 |
+
def eval_i21k(split):
|
| 69 |
+
return dict(
|
| 70 |
+
type='classification',
|
| 71 |
+
data={**config.input.data, 'split': split},
|
| 72 |
+
pp_fn=pp_eval + pp_common_i21k,
|
| 73 |
+
loss_name=config.loss,
|
| 74 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
| 75 |
+
)
|
| 76 |
+
config.evals = {}
|
| 77 |
+
config.evals.test = eval_i21k('full[:25_600]')
|
| 78 |
+
config.evals.val = eval_i21k('full[25_600:51_200]')
|
| 79 |
+
config.evals.train = eval_i21k('full[51_200:76_800]')
|
| 80 |
+
|
| 81 |
+
# Few-shot evaluators
|
| 82 |
+
config.evals.fewshot = get_fewshot_lsr()
|
| 83 |
+
config.evals.fewshot.log_steps = 25_000
|
| 84 |
+
|
| 85 |
+
return config
|
big_vision/configs/common.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""A few things commonly used across A LOT of config files."""
|
| 16 |
+
|
| 17 |
+
import string
|
| 18 |
+
|
| 19 |
+
import ml_collections as mlc
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def input_for_quicktest(config_input, quicktest):
|
| 23 |
+
if quicktest:
|
| 24 |
+
config_input.batch_size = 8
|
| 25 |
+
config_input.shuffle_buffer_size = 10
|
| 26 |
+
config_input.cache_raw = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_arg(arg, lazy=False, **spec):
|
| 30 |
+
"""Makes ConfigDict's get_config single-string argument more usable.
|
| 31 |
+
|
| 32 |
+
Example use in the config file:
|
| 33 |
+
|
| 34 |
+
import big_vision.configs.common as bvcc
|
| 35 |
+
def get_config(arg):
|
| 36 |
+
arg = bvcc.parse_arg(arg,
|
| 37 |
+
res=(224, int),
|
| 38 |
+
runlocal=False,
|
| 39 |
+
schedule='short',
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# ...
|
| 43 |
+
|
| 44 |
+
config.shuffle_buffer = 250_000 if not arg.runlocal else 50
|
| 45 |
+
|
| 46 |
+
Ways that values can be passed when launching:
|
| 47 |
+
|
| 48 |
+
--config amazing.py:runlocal,schedule=long,res=128
|
| 49 |
+
--config amazing.py:res=128
|
| 50 |
+
--config amazing.py:runlocal # A boolean needs no value for "true".
|
| 51 |
+
--config amazing.py:runlocal=False # Explicit false boolean.
|
| 52 |
+
--config amazing.py:128 # The first spec entry may be passed unnamed alone.
|
| 53 |
+
|
| 54 |
+
Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
|
| 55 |
+
'false', '' to False).
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
arg: the string argument that's passed to get_config.
|
| 59 |
+
lazy: allow lazy parsing of arguments, which are not in spec. For these,
|
| 60 |
+
the type is auto-extracted in dependence of most complex possible type.
|
| 61 |
+
**spec: the name and default values of the expected options.
|
| 62 |
+
If the value is a tuple, the value's first element is the default value,
|
| 63 |
+
and the second element is a function called to convert the string.
|
| 64 |
+
Otherwise the type is automatically extracted from the default value.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
ConfigDict object with extracted type-converted values.
|
| 68 |
+
"""
|
| 69 |
+
# Normalize arg and spec layout.
|
| 70 |
+
arg = arg or '' # Normalize None to empty string
|
| 71 |
+
spec = {k: get_type_with_default(v) for k, v in spec.items()}
|
| 72 |
+
|
| 73 |
+
result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only.
|
| 74 |
+
|
| 75 |
+
# Expand convenience-cases for a single parameter without = sign.
|
| 76 |
+
if arg and ',' not in arg and '=' not in arg:
|
| 77 |
+
# (think :runlocal) If it's the name of sth in the spec (or there is no
|
| 78 |
+
# spec), it's that in bool.
|
| 79 |
+
if arg in spec or not spec:
|
| 80 |
+
arg = f'{arg}=True'
|
| 81 |
+
# Otherwise, it is the value for the first entry in the spec.
|
| 82 |
+
else:
|
| 83 |
+
arg = f'{list(spec.keys())[0]}={arg}'
|
| 84 |
+
# Yes, we rely on Py3.7 insertion order!
|
| 85 |
+
|
| 86 |
+
# Now, expand the `arg` string into a dict of keys and values:
|
| 87 |
+
raw_kv = {raw_arg.split('=')[0]:
|
| 88 |
+
raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
|
| 89 |
+
for raw_arg in arg.split(',') if raw_arg}
|
| 90 |
+
|
| 91 |
+
# And go through the spec, using provided or default value for each:
|
| 92 |
+
for name, (default, type_fn) in spec.items():
|
| 93 |
+
val = raw_kv.pop(name, None)
|
| 94 |
+
result[name] = type_fn(val) if val is not None else default
|
| 95 |
+
|
| 96 |
+
if raw_kv:
|
| 97 |
+
if lazy: # Process args which are not in spec.
|
| 98 |
+
for k, v in raw_kv.items():
|
| 99 |
+
result[k] = autotype(v)
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f'Unhandled config args remain: {raw_kv}')
|
| 102 |
+
|
| 103 |
+
return result
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_type_with_default(v):
|
| 107 |
+
"""Returns (v, string_to_v_type) with lenient bool parsing."""
|
| 108 |
+
# For bool, do safe string conversion.
|
| 109 |
+
if isinstance(v, bool):
|
| 110 |
+
def strict_bool(x):
|
| 111 |
+
assert x.lower() in {'true', 'false', ''}
|
| 112 |
+
return x.lower() == 'true'
|
| 113 |
+
return (v, strict_bool)
|
| 114 |
+
# If already a (default, type) tuple, use that.
|
| 115 |
+
if isinstance(v, (tuple, list)):
|
| 116 |
+
assert len(v) == 2 and isinstance(v[1], type), (
|
| 117 |
+
'List or tuple types are currently not supported because we use `,` as'
|
| 118 |
+
' dumb delimiter. Contributions (probably using ast) welcome. You can'
|
| 119 |
+
' unblock by using a string with eval(s.replace(";", ",")) or similar')
|
| 120 |
+
return (v[0], v[1])
|
| 121 |
+
# Otherwise, derive the type from the default value.
|
| 122 |
+
return (v, type(v))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def autotype(x):
|
| 126 |
+
"""Auto-converts string to bool/int/float if possible."""
|
| 127 |
+
assert isinstance(x, str)
|
| 128 |
+
if x.lower() in {'true', 'false'}:
|
| 129 |
+
return x.lower() == 'true' # Returns as bool.
|
| 130 |
+
try:
|
| 131 |
+
return int(x) # Returns as int.
|
| 132 |
+
except ValueError:
|
| 133 |
+
try:
|
| 134 |
+
return float(x) # Returns as float.
|
| 135 |
+
except ValueError:
|
| 136 |
+
return x # Returns as str.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def pack_arg(**kw):
|
| 140 |
+
"""Packs key-word args as a string to be parsed by `parse_arg()`."""
|
| 141 |
+
for v in kw.values():
|
| 142 |
+
assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
|
| 143 |
+
return ','.join([f'{k}={v}' for k, v in kw.items()])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def arg(**kw):
|
| 147 |
+
"""Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
|
| 148 |
+
return {'config_arg': pack_arg(**kw), **kw}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _get_field_ref(config_dict, field_name):
|
| 152 |
+
path = field_name.split('.')
|
| 153 |
+
for field in path[:-1]:
|
| 154 |
+
config_dict = getattr(config_dict, field)
|
| 155 |
+
return config_dict.get_ref(path[-1])
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def format_str(format_string, config):
|
| 159 |
+
"""Format string with reference fields from config.
|
| 160 |
+
|
| 161 |
+
This makes it easy to build preprocess strings that contain references to
|
| 162 |
+
fields tha are edited after. E.g.:
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
config = mlc.ConficDict()
|
| 166 |
+
config.res = (256, 256)
|
| 167 |
+
config.pp = bvcc.format_str('resize({res})', config)
|
| 168 |
+
...
|
| 169 |
+
# if config.res is modified (e.g. via sweeps) it will propagate to pp field:
|
| 170 |
+
config.res = (512, 512)
|
| 171 |
+
assert config.pp == 'resize((512, 512))'
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
format_string: string to format with references.
|
| 176 |
+
config: ConfigDict to get references to format the string.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
A reference field which renders a string using references to config fields.
|
| 180 |
+
"""
|
| 181 |
+
output = ''
|
| 182 |
+
parts = string.Formatter().parse(format_string)
|
| 183 |
+
for (literal_text, field_name, format_spec, conversion) in parts:
|
| 184 |
+
assert not format_spec and not conversion
|
| 185 |
+
output += literal_text
|
| 186 |
+
if field_name:
|
| 187 |
+
output += _get_field_ref(config, field_name).to_str()
|
| 188 |
+
return output
|
big_vision/configs/common_fewshot.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Most common few-shot eval configuration."""
|
| 16 |
+
|
| 17 |
+
import ml_collections as mlc
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
|
| 21 |
+
runlocal=False, pp=None, **kw):
|
| 22 |
+
"""Returns a standard-ish fewshot eval configuration."""
|
| 23 |
+
kw.setdefault('representation_layer', 'pre_logits')
|
| 24 |
+
kw.setdefault('shots', (1, 5, 10, 25))
|
| 25 |
+
kw.setdefault('l2_reg', 2.0 ** 10)
|
| 26 |
+
kw.setdefault('num_seeds', 3)
|
| 27 |
+
kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/
|
| 28 |
+
|
| 29 |
+
# Backward-compatible default:
|
| 30 |
+
if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long
|
| 31 |
+
kw['log_steps'] = 25_000
|
| 32 |
+
|
| 33 |
+
config = mlc.ConfigDict(kw)
|
| 34 |
+
config.type = 'fewshot_lsr'
|
| 35 |
+
config.datasets = {
|
| 36 |
+
'caltech': ('caltech101', 'train', 'test'), # copybara:srtip
|
| 37 |
+
'cars': ('cars196:2.1.0', 'train', 'test'),
|
| 38 |
+
'cifar100': ('cifar100', 'train', 'test'),
|
| 39 |
+
'dtd': ('dtd', 'train', 'test'),
|
| 40 |
+
# The first 65000 ImageNet samples have at least 30 shots per any class.
|
| 41 |
+
# Commented out by default because needs manual download.
|
| 42 |
+
# 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
|
| 43 |
+
'pets': ('oxford_iiit_pet', 'train', 'test'),
|
| 44 |
+
'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
|
| 45 |
+
} if not runlocal else {
|
| 46 |
+
'pets': ('oxford_iiit_pet', 'train', 'test'),
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
pp = pp or '|'.join([
|
| 50 |
+
'decode',
|
| 51 |
+
f'resize({resize_resolution})',
|
| 52 |
+
f'central_crop({target_resolution})',
|
| 53 |
+
'value_range(-1,1)'
|
| 54 |
+
])
|
| 55 |
+
pp += '|keep("image", "label")'
|
| 56 |
+
config.pp_train = pp
|
| 57 |
+
config.pp_eval = pp
|
| 58 |
+
config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
|
| 59 |
+
|
| 60 |
+
return config
|
big_vision/configs/load_and_eval.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pytype: disable=not-writable,attribute-error
|
| 16 |
+
# pylint: disable=line-too-long,missing-function-docstring
|
| 17 |
+
r"""A config to load and eval key model using the core train.py.
|
| 18 |
+
|
| 19 |
+
The runtime varies widely depending on the model, but each one should reproduce
|
| 20 |
+
the corresponding paper's numbers.
|
| 21 |
+
This configuration makes use of the "arg" to get_config to select which model
|
| 22 |
+
to run, so a few examples are given below:
|
| 23 |
+
|
| 24 |
+
Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k:
|
| 25 |
+
|
| 26 |
+
big_vision.train \
|
| 27 |
+
--config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \
|
| 28 |
+
--config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50
|
| 29 |
+
|
| 30 |
+
Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper:
|
| 31 |
+
|
| 32 |
+
big_vision.train \
|
| 33 |
+
--config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \
|
| 34 |
+
--config.model.variant B/32 --config.model_init howto-i21k-B/32
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import big_vision.configs.common as bvcc
|
| 38 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def eval_only(config, batch_size, spec_for_init):
|
| 42 |
+
"""Set a few configs that turn trainer into (almost) eval-only."""
|
| 43 |
+
config.total_steps = 0
|
| 44 |
+
config.input = {}
|
| 45 |
+
config.input.batch_size = batch_size
|
| 46 |
+
config.input.data = dict(name='bv:dummy', spec=spec_for_init)
|
| 47 |
+
config.optax_name = 'identity'
|
| 48 |
+
config.lr = 0.0
|
| 49 |
+
|
| 50 |
+
config.mesh = [('data', -1)]
|
| 51 |
+
config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
|
| 52 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
| 53 |
+
|
| 54 |
+
return config
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_config(arg=''):
|
| 58 |
+
config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)
|
| 59 |
+
|
| 60 |
+
# Make the config eval-only by setting some dummies.
|
| 61 |
+
eval_only(config, config.batch_size, spec_for_init=dict(
|
| 62 |
+
image=dict(shape=(224, 224, 3), dtype='float32'),
|
| 63 |
+
))
|
| 64 |
+
|
| 65 |
+
config.evals = dict(fewshot=get_fewshot_lsr())
|
| 66 |
+
|
| 67 |
+
# Just calls the function with the name given as `config`.
|
| 68 |
+
# Could also be a giant if-block if you're into that kind of thing.
|
| 69 |
+
globals()[config.name](config)
|
| 70 |
+
return config
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def bit_paper(config):
|
| 74 |
+
config.num_classes = 1000
|
| 75 |
+
|
| 76 |
+
config.model_name = 'bit_paper'
|
| 77 |
+
config.model_init = 'M-imagenet2012' # M = i21k, -imagenet2012 = fine-tuned
|
| 78 |
+
config.model = dict(width=1, depth=50)
|
| 79 |
+
|
| 80 |
+
def get_eval(split, lbl, dataset='imagenet2012_real'):
|
| 81 |
+
return dict(
|
| 82 |
+
type='classification',
|
| 83 |
+
data=dict(name=dataset, split=split),
|
| 84 |
+
loss_name='softmax_xent',
|
| 85 |
+
cache='none', # Only run once, on low-mem machine.
|
| 86 |
+
pp_fn=(
|
| 87 |
+
'decode|resize(384)|value_range(-1, 1)'
|
| 88 |
+
f'|onehot(1000, key="{lbl}", key_result="labels")'
|
| 89 |
+
'|keep("image", "labels")'
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
config.evals.test = get_eval('validation', 'original_label')
|
| 93 |
+
config.evals.real = get_eval('validation', 'real_label')
|
| 94 |
+
config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def vit_i1k(config):
|
| 98 |
+
config.num_classes = 1000
|
| 99 |
+
|
| 100 |
+
config.model_name = 'vit'
|
| 101 |
+
config.model_init = '' # Will be set in sweep.
|
| 102 |
+
config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
|
| 103 |
+
rep_size=True)
|
| 104 |
+
|
| 105 |
+
config.evals.val = dict(
|
| 106 |
+
type='classification',
|
| 107 |
+
data=dict(name='imagenet2012', split='validation'),
|
| 108 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
|
| 109 |
+
loss_name='softmax_xent',
|
| 110 |
+
cache='none', # Only run once, on low-mem machine.
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def mlp_mixer_i1k(config):
|
| 115 |
+
config.num_classes = 1000
|
| 116 |
+
|
| 117 |
+
config.model_name = 'mlp_mixer'
|
| 118 |
+
config.model_init = '' # Will be set in sweep.
|
| 119 |
+
config.model = dict(variant='L/16')
|
| 120 |
+
|
| 121 |
+
config.evals.val = dict(
|
| 122 |
+
type='classification',
|
| 123 |
+
data=dict(name='imagenet2012', split='validation'),
|
| 124 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
|
| 125 |
+
loss_name='softmax_xent',
|
| 126 |
+
cache='none', # Only run once, on low-mem machine.
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def vit_i21k(config):
|
| 131 |
+
config.num_classes = 21843
|
| 132 |
+
|
| 133 |
+
config.model_name = 'vit'
|
| 134 |
+
config.model_init = '' # Will be set in sweep.
|
| 135 |
+
config.model = dict(variant='B/32', pool_type='tok')
|
| 136 |
+
|
| 137 |
+
config.evals.val = dict(
|
| 138 |
+
type='classification',
|
| 139 |
+
data=dict(name='imagenet21k', split='full[:51200]'),
|
| 140 |
+
pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")',
|
| 141 |
+
loss_name='sigmoid_xent',
|
| 142 |
+
cache='none', # Only run once, on low-mem machine.
|
| 143 |
+
)
|
big_vision/configs/mlp_mixer_i1k.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k").
|
| 17 |
+
|
| 18 |
+
Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128
|
| 19 |
+
with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m.
|
| 20 |
+
|
| 21 |
+
big_vision.train \
|
| 22 |
+
--config big_vision/configs/mlp_mixer_i1k.py \
|
| 23 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 27 |
+
import ml_collections as mlc
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_config(mode=None):
|
| 31 |
+
"""Config for training Mixer on i1k."""
|
| 32 |
+
config = mlc.ConfigDict()
|
| 33 |
+
|
| 34 |
+
config.seed = 0
|
| 35 |
+
config.total_epochs = 300
|
| 36 |
+
config.num_classes = 1000
|
| 37 |
+
config.loss = 'sigmoid_xent'
|
| 38 |
+
config.init_head_bias = -6.9
|
| 39 |
+
|
| 40 |
+
config.input = dict()
|
| 41 |
+
config.input.data = dict(
|
| 42 |
+
name='imagenet2012',
|
| 43 |
+
split='train[:99%]',
|
| 44 |
+
)
|
| 45 |
+
config.input.batch_size = 4096
|
| 46 |
+
config.input.cache_raw = True # Needs up to 120GB of RAM!
|
| 47 |
+
config.input.shuffle_buffer_size = 250_000
|
| 48 |
+
|
| 49 |
+
config.input.pp = (
|
| 50 |
+
'decode_jpeg_and_inception_crop(224)'
|
| 51 |
+
'|flip_lr'
|
| 52 |
+
'|randaug(2,15)'
|
| 53 |
+
'|value_range(-1, 1)'
|
| 54 |
+
'|onehot(1000, key="label", key_result="labels")'
|
| 55 |
+
'|keep("image", "labels")'
|
| 56 |
+
)
|
| 57 |
+
pp_eval = (
|
| 58 |
+
'decode'
|
| 59 |
+
'|resize_small(256)|central_crop(224)'
|
| 60 |
+
'|value_range(-1, 1)'
|
| 61 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
| 62 |
+
'|keep("image", "labels")'
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# To continue using the near-defunct randaug op.
|
| 66 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
| 67 |
+
|
| 68 |
+
config.log_training_steps = 50
|
| 69 |
+
config.ckpt_steps = 1000
|
| 70 |
+
|
| 71 |
+
config.prefetch_to_device = 2
|
| 72 |
+
|
| 73 |
+
# Model section
|
| 74 |
+
config.model_name = 'mlp_mixer'
|
| 75 |
+
config.model = dict()
|
| 76 |
+
config.model.variant = 'B/16'
|
| 77 |
+
config.model.stoch_depth = 0.1
|
| 78 |
+
|
| 79 |
+
config.mixup = dict(fold_in=None, p=0.5)
|
| 80 |
+
|
| 81 |
+
# Optimizer section
|
| 82 |
+
config.optax_name = 'scale_by_adam'
|
| 83 |
+
config.grad_clip_norm = 1.
|
| 84 |
+
|
| 85 |
+
config.lr = 0.001
|
| 86 |
+
config.wd = 1e-4
|
| 87 |
+
config.schedule = dict(
|
| 88 |
+
decay_type='linear',
|
| 89 |
+
warmup_steps=10_000,
|
| 90 |
+
linear_end=1e-5,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Eval section
|
| 94 |
+
def get_eval(split, dataset='imagenet2012'):
|
| 95 |
+
return dict(
|
| 96 |
+
type='classification',
|
| 97 |
+
data=dict(name=dataset, split=split),
|
| 98 |
+
pp_fn=pp_eval.format(lbl='label'),
|
| 99 |
+
loss_name=config.loss,
|
| 100 |
+
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
|
| 101 |
+
cache_final=mode != 'gpu8',
|
| 102 |
+
)
|
| 103 |
+
config.evals = {}
|
| 104 |
+
config.evals.train = get_eval('train[:2%]')
|
| 105 |
+
config.evals.minival = get_eval('train[99%:]')
|
| 106 |
+
config.evals.val = get_eval('validation')
|
| 107 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
| 108 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
| 109 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
| 110 |
+
|
| 111 |
+
config.fewshot = get_fewshot_lsr()
|
| 112 |
+
|
| 113 |
+
if mode == 'gpu8':
|
| 114 |
+
config.total_epochs = 60
|
| 115 |
+
config.input.batch_size = 512
|
| 116 |
+
config.input.cache_raw = False
|
| 117 |
+
if mode == 'regression_test':
|
| 118 |
+
config.total_epochs = 60
|
| 119 |
+
|
| 120 |
+
return config
|
big_vision/configs/transfer.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long,missing-function-docstring
|
| 16 |
+
r"""A config for transferring vit-augreg.
|
| 17 |
+
|
| 18 |
+
Best HP selected on (mini)val, expected test results (repeated 5 times):
|
| 19 |
+
|
| 20 |
+
ViT-Augreg-B/32:
|
| 21 |
+
Dataset, crop, learning rate, mean (%), range (%)
|
| 22 |
+
- ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33]
|
| 23 |
+
- Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6]
|
| 24 |
+
- Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62]
|
| 25 |
+
- Pets, inception_crop, 0.003, 93.78, [93.62...94.00]
|
| 26 |
+
- Flowers, inception_crop, 0.003, 99.43, [99.42...99.45]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
Command to run:
|
| 30 |
+
big_vision.train \
|
| 31 |
+
--config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \
|
| 32 |
+
--workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import big_vision.configs.common as bvcc
|
| 36 |
+
import ml_collections as mlc
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _set_model(config, model):
|
| 40 |
+
"""Load pre-trained models: vit or bit."""
|
| 41 |
+
# Reset the head to init (of zeros) when transferring.
|
| 42 |
+
config.model_load = dict(dont_load=['head/kernel', 'head/bias'])
|
| 43 |
+
|
| 44 |
+
if model == 'vit-i21k-augreg-b/32':
|
| 45 |
+
# Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
|
| 46 |
+
config.model_name = 'vit'
|
| 47 |
+
config.model_init = 'howto-i21k-B/32'
|
| 48 |
+
config.model = dict(variant='B/32', pool_type='tok')
|
| 49 |
+
elif model == 'vit-i21k-augreg-l/16':
|
| 50 |
+
config.model_name = 'vit'
|
| 51 |
+
config.model_init = 'howto-i21k-L/16'
|
| 52 |
+
config.model = dict(variant='L/16', pool_type='tok')
|
| 53 |
+
elif model == 'vit-s16':
|
| 54 |
+
config.model_name = 'vit'
|
| 55 |
+
config.model_init = 'i1k-s16-300ep'
|
| 56 |
+
config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
|
| 57 |
+
rep_size=True)
|
| 58 |
+
elif model == 'bit-m-r50x1':
|
| 59 |
+
config.model_name = 'bit_paper'
|
| 60 |
+
config.model_init = 'M'
|
| 61 |
+
config.model = dict(depth=50, width=1)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f'Unknown model: {model}, please define customized model.')
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
|
| 67 |
+
if dataset == 'cifar10':
|
| 68 |
+
_set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
|
| 69 |
+
elif dataset == 'cifar100':
|
| 70 |
+
_set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
|
| 71 |
+
elif dataset == 'imagenet2012':
|
| 72 |
+
_set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
|
| 73 |
+
_set_imagenet_variants(config)
|
| 74 |
+
elif dataset == 'oxford_iiit_pet':
|
| 75 |
+
_set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
|
| 76 |
+
elif dataset == 'oxford_flowers102':
|
| 77 |
+
_set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f'Unknown dataset: {dataset}, please define customized dataset.')
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _set_task(config, dataset, train, val, test, n_cls,
|
| 84 |
+
steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
|
| 85 |
+
flip=True, h_res=448, l_res=384):
|
| 86 |
+
"""Vision task with val and test splits."""
|
| 87 |
+
config.total_steps = steps
|
| 88 |
+
config.schedule = dict(
|
| 89 |
+
warmup_steps=warmup,
|
| 90 |
+
decay_type='cosine',
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
config.input.data = dict(name=dataset, split=train)
|
| 94 |
+
pp_common = (
|
| 95 |
+
'|value_range(-1, 1)|'
|
| 96 |
+
f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
|
| 97 |
+
'keep("image", "labels")'
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if crop == 'inception_crop':
|
| 101 |
+
pp_train = f'decode|inception_crop({l_res})'
|
| 102 |
+
elif crop == 'resmall_crop':
|
| 103 |
+
pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
|
| 104 |
+
elif crop == 'resize_crop':
|
| 105 |
+
pp_train = f'decode|resize({h_res})|random_crop({l_res})'
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f'Unknown crop: {crop}. Must be one of: '
|
| 108 |
+
'inception_crop, resmall_crop, resize_crop')
|
| 109 |
+
if flip:
|
| 110 |
+
pp_train += '|flip_lr'
|
| 111 |
+
config.input.pp = pp_train + pp_common
|
| 112 |
+
|
| 113 |
+
pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
|
| 114 |
+
config.num_classes = n_cls
|
| 115 |
+
|
| 116 |
+
def get_eval(split):
|
| 117 |
+
return dict(
|
| 118 |
+
type='classification',
|
| 119 |
+
data=dict(name=dataset, split=split),
|
| 120 |
+
loss_name='softmax_xent',
|
| 121 |
+
log_steps=100,
|
| 122 |
+
pp_fn=pp,
|
| 123 |
+
)
|
| 124 |
+
config.evals = dict(val=get_eval(val), test=get_eval(test))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _set_imagenet_variants(config, h_res=448, l_res=384):
|
| 128 |
+
"""Evaluation tasks on ImageNet variants: v2 and real."""
|
| 129 |
+
pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
|
| 130 |
+
'|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
|
| 131 |
+
'keep("image", "labels")'
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Special-case rename for i1k (val+test -> minival+val)
|
| 135 |
+
config.evals.minival = config.evals.val
|
| 136 |
+
config.evals.val = config.evals.test
|
| 137 |
+
# NOTE: keep test == val for convenience in subsequent analysis.
|
| 138 |
+
|
| 139 |
+
config.evals.real = dict(type='classification')
|
| 140 |
+
config.evals.real.data = dict(name='imagenet2012_real', split='validation')
|
| 141 |
+
config.evals.real.pp_fn = pp.format(lbl='real_label')
|
| 142 |
+
config.evals.real.loss_name = config.loss
|
| 143 |
+
config.evals.real.log_steps = 100
|
| 144 |
+
|
| 145 |
+
config.evals.v2 = dict(type='classification')
|
| 146 |
+
config.evals.v2.data = dict(name='imagenet_v2', split='test')
|
| 147 |
+
config.evals.v2.pp_fn = pp.format(lbl='label')
|
| 148 |
+
config.evals.v2.loss_name = config.loss
|
| 149 |
+
config.evals.v2.log_steps = 100
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_config(arg=None):
|
| 153 |
+
"""Config for adaptation."""
|
| 154 |
+
arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
|
| 155 |
+
h_res=448, l_res=384, batch_size=512, fsdp=False,
|
| 156 |
+
runlocal=False)
|
| 157 |
+
config = mlc.ConfigDict()
|
| 158 |
+
|
| 159 |
+
config.input = {}
|
| 160 |
+
config.input.batch_size = arg.batch_size if not arg.runlocal else 8
|
| 161 |
+
config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100
|
| 162 |
+
|
| 163 |
+
config.log_training_steps = 10
|
| 164 |
+
config.ckpt_steps = 1000
|
| 165 |
+
config.ckpt_timeout = 600
|
| 166 |
+
|
| 167 |
+
# Optimizer section
|
| 168 |
+
config.optax_name = 'big_vision.momentum_hp'
|
| 169 |
+
config.grad_clip_norm = 1.0
|
| 170 |
+
config.wd = None # That's our default, but just being explicit here!
|
| 171 |
+
config.loss = 'softmax_xent'
|
| 172 |
+
config.lr = 0.01
|
| 173 |
+
config.mixup = dict(p=0.0)
|
| 174 |
+
|
| 175 |
+
config.seed = 0
|
| 176 |
+
|
| 177 |
+
_set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)
|
| 178 |
+
|
| 179 |
+
_set_model(config, arg.model)
|
| 180 |
+
if arg.fsdp:
|
| 181 |
+
config.mesh = [('data', -1)]
|
| 182 |
+
config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
| 183 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
| 184 |
+
config.model.scan = True
|
| 185 |
+
|
| 186 |
+
return config
|
big_vision/configs/vit_i1k.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270
|
| 17 |
+
|
| 18 |
+
This config does NOT include regularization (dropout, stochastic depth), which
|
| 19 |
+
was shown to help with B/32, B/16, L/16 models in the paper (Figure 4).
|
| 20 |
+
|
| 21 |
+
This configuration makes use of the "arg" to get_config to select which model
|
| 22 |
+
to run, so a few examples are given below:
|
| 23 |
+
|
| 24 |
+
Run training of a B/16 model:
|
| 25 |
+
|
| 26 |
+
big_vision.train \
|
| 27 |
+
--config big_vision/configs/vit_i1k.py:variant=B/16 \
|
| 28 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
|
| 29 |
+
|
| 30 |
+
Run training of a B/32 model with custom aug-strenght and 300ep:
|
| 31 |
+
|
| 32 |
+
big_vision.train \
|
| 33 |
+
--config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \
|
| 34 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
|
| 35 |
+
--config.total_epochs 300
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import big_vision.configs.common as bvcc
|
| 39 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 40 |
+
import ml_collections as mlc
|
| 41 |
+
|
| 42 |
+
MIXUP_DEF = {
|
| 43 |
+
'none': dict(p=0.0, fold_in=None),
|
| 44 |
+
'light1': dict(p=0.0, fold_in=None),
|
| 45 |
+
'light2': dict(p=0.2, fold_in=None),
|
| 46 |
+
'medium1': dict(p=0.2, fold_in=None),
|
| 47 |
+
'medium2': dict(p=0.5, fold_in=None),
|
| 48 |
+
'strong1': dict(p=0.5, fold_in=None),
|
| 49 |
+
'strong2': dict(p=0.8, fold_in=None),
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
RANDAUG_DEF = {
|
| 53 |
+
'none': '',
|
| 54 |
+
'light1': 'randaug(2,0)', # Actually not nothing!
|
| 55 |
+
'light2': 'randaug(2,10)',
|
| 56 |
+
'medium1': 'randaug(2,15)',
|
| 57 |
+
'medium2': 'randaug(2,15)',
|
| 58 |
+
'strong1': 'randaug(2,20)',
|
| 59 |
+
'strong2': 'randaug(2,20)',
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_config(arg=None):
|
| 64 |
+
"""Config for training."""
|
| 65 |
+
arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='')
|
| 66 |
+
config = mlc.ConfigDict()
|
| 67 |
+
|
| 68 |
+
config.seed = 0
|
| 69 |
+
config.total_epochs = 300
|
| 70 |
+
config.num_classes = 1000
|
| 71 |
+
config.loss = 'sigmoid_xent'
|
| 72 |
+
config.init_head_bias = -6.9
|
| 73 |
+
|
| 74 |
+
# If this gives a KeyError, lookup Fig4 of the paper and add an entry.
|
| 75 |
+
# Note, this here is a good average between 30ep and 300ep, sometimes you coud
|
| 76 |
+
# find a slightly better setting for either of them.
|
| 77 |
+
aug_setting = arg.aug or {
|
| 78 |
+
'Ti/16': 'light1',
|
| 79 |
+
'S/32': 'medium1',
|
| 80 |
+
'S/16': 'medium2',
|
| 81 |
+
'B/32': 'medium2',
|
| 82 |
+
'B/16': 'medium2',
|
| 83 |
+
'L/16': 'medium2',
|
| 84 |
+
}[arg.variant]
|
| 85 |
+
|
| 86 |
+
config.input = dict()
|
| 87 |
+
config.input.data = dict(
|
| 88 |
+
name='imagenet2012',
|
| 89 |
+
split='train[:99%]',
|
| 90 |
+
)
|
| 91 |
+
config.input.batch_size = 4096
|
| 92 |
+
config.input.cache = 'raw_data' if arg.runlocal else 'none' # Needs up to 120GB of RAM!
|
| 93 |
+
config.input.shuffle_buffer_size = 250_000
|
| 94 |
+
|
| 95 |
+
pp_common = (
|
| 96 |
+
'|value_range(-1, 1)'
|
| 97 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
| 98 |
+
'|keep("image", "labels")'
|
| 99 |
+
)
|
| 100 |
+
config.input.pp = (
|
| 101 |
+
'decode_jpeg_and_inception_crop(224)|flip_lr|' +
|
| 102 |
+
RANDAUG_DEF[aug_setting] +
|
| 103 |
+
pp_common.format(lbl='label')
|
| 104 |
+
)
|
| 105 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
| 106 |
+
|
| 107 |
+
# To continue using the near-defunct randaug op.
|
| 108 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
| 109 |
+
|
| 110 |
+
# Aggressive pre-fetching because our models here are small, so we not only
|
| 111 |
+
# can afford it, but we also need it for the smallest models to not be
|
| 112 |
+
# bottle-necked by the input pipeline. Play around with it for -L models tho.
|
| 113 |
+
config.input.prefetch = 8
|
| 114 |
+
config.prefetch_to_device = 4
|
| 115 |
+
|
| 116 |
+
config.log_training_steps = 50
|
| 117 |
+
config.ckpt_steps = 1000
|
| 118 |
+
|
| 119 |
+
# Model section
|
| 120 |
+
config.model_name = 'vit'
|
| 121 |
+
config.model = dict(
|
| 122 |
+
variant=arg.variant,
|
| 123 |
+
rep_size=True,
|
| 124 |
+
pool_type='tok',
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Optimizer section
|
| 128 |
+
config.grad_clip_norm = 1.0
|
| 129 |
+
config.optax_name = 'scale_by_adam'
|
| 130 |
+
config.optax = dict(mu_dtype='bfloat16')
|
| 131 |
+
# The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
|
| 132 |
+
# almost always behaves exactly like adam, but at a fraction of the memory
|
| 133 |
+
# cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
|
| 134 |
+
# good idea to try it when you are memory-bound!
|
| 135 |
+
# config.optax_name = 'big_vision.scale_by_adafactor'
|
| 136 |
+
# A good flag to play with when hitting instabilities, is the following:
|
| 137 |
+
# config.optax = dict(beta2_cap=0.95)
|
| 138 |
+
|
| 139 |
+
config.lr = 0.001
|
| 140 |
+
config.wd = 0.0001
|
| 141 |
+
config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
|
| 142 |
+
|
| 143 |
+
config.mixup = MIXUP_DEF[aug_setting]
|
| 144 |
+
|
| 145 |
+
# Eval section
|
| 146 |
+
def get_eval(split, dataset='imagenet2012'):
|
| 147 |
+
return dict(
|
| 148 |
+
type='classification',
|
| 149 |
+
data=dict(name=dataset, split=split),
|
| 150 |
+
pp_fn=pp_eval.format(lbl='label'),
|
| 151 |
+
loss_name=config.loss,
|
| 152 |
+
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
|
| 153 |
+
cache='final_data' if arg.runlocal else 'none',
|
| 154 |
+
)
|
| 155 |
+
config.evals = {}
|
| 156 |
+
config.evals.train = get_eval('train[:2%]')
|
| 157 |
+
config.evals.minival = get_eval('train[99%:]')
|
| 158 |
+
config.evals.val = get_eval('validation')
|
| 159 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
| 160 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
| 161 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
| 162 |
+
|
| 163 |
+
config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
|
| 164 |
+
config.fewshot.log_steps = 10_000
|
| 165 |
+
|
| 166 |
+
# Make a few things much smaller for quick local debugging testruns.
|
| 167 |
+
if arg.runlocal:
|
| 168 |
+
config.input.shuffle_buffer_size = 10
|
| 169 |
+
config.input.batch_size = 8
|
| 170 |
+
config.input.cache_raw = False
|
| 171 |
+
config.evals.train.data.split = 'train[:16]'
|
| 172 |
+
config.evals.minival.data.split = 'train[:16]'
|
| 173 |
+
config.evals.val.data.split = 'validation[:16]'
|
| 174 |
+
config.evals.v2.data.split = 'test[:16]'
|
| 175 |
+
config.evals.real.data.split = 'validation[:16]'
|
| 176 |
+
|
| 177 |
+
return config
|
big_vision/configs/vit_i21k.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Pre-training ViT on ImageNet-21k as in https://arxiv.org/abs/2106.10270
|
| 17 |
+
|
| 18 |
+
This config relies on the Imagenet-21k tfds dataset, which is not yet
|
| 19 |
+
available publicly in TFDS. We intend to add the dataset to public TFDS soon,
|
| 20 |
+
and this config will then be runnable.
|
| 21 |
+
|
| 22 |
+
Note that regularization (dropout, stochastic depth) is not currently
|
| 23 |
+
implemented. This was not beneficial for ImageNet-21k pre-trainning.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import big_vision.configs.common as bvcc
|
| 27 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 28 |
+
import ml_collections as mlc
|
| 29 |
+
|
| 30 |
+
MIXUP_DEF = {
|
| 31 |
+
'none': dict(p=0.0, fold_in=None),
|
| 32 |
+
'light1': dict(p=0.0, fold_in=None),
|
| 33 |
+
'light2': dict(p=0.2, fold_in=None),
|
| 34 |
+
'medium1': dict(p=0.2, fold_in=None),
|
| 35 |
+
'medium2': dict(p=0.5, fold_in=None),
|
| 36 |
+
'strong1': dict(p=0.5, fold_in=None),
|
| 37 |
+
'strong2': dict(p=0.8, fold_in=None),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
RANDAUG_DEF = {
|
| 41 |
+
'none': '',
|
| 42 |
+
'light1': 'randaug(2,0)', # Actually not nothing!
|
| 43 |
+
'light2': 'randaug(2,10)',
|
| 44 |
+
'medium1': 'randaug(2,15)',
|
| 45 |
+
'medium2': 'randaug(2,15)',
|
| 46 |
+
'strong1': 'randaug(2,20)',
|
| 47 |
+
'strong2': 'randaug(2,20)',
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_config(arg=None):
|
| 52 |
+
"""Config for training."""
|
| 53 |
+
arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None)
|
| 54 |
+
config = mlc.ConfigDict()
|
| 55 |
+
|
| 56 |
+
config.seed = 0
|
| 57 |
+
config.total_epochs = 300
|
| 58 |
+
config.num_classes = 21843
|
| 59 |
+
config.init_head_bias = -10.0
|
| 60 |
+
config.loss = 'sigmoid_xent'
|
| 61 |
+
|
| 62 |
+
# If this gives a KeyError, lookup Fig4 of the paper and add an entry.
|
| 63 |
+
# Note, this here is a good average between 30ep and 300ep, sometimes you coud
|
| 64 |
+
# find a slightly better setting for either of them.
|
| 65 |
+
aug_setting = {
|
| 66 |
+
'Ti/16': 'none',
|
| 67 |
+
'S/32': 'none',
|
| 68 |
+
'S/16': 'light1',
|
| 69 |
+
'B/32': 'light2',
|
| 70 |
+
'B/16': 'light2',
|
| 71 |
+
'L/16': 'medium2',
|
| 72 |
+
}[arg.variant]
|
| 73 |
+
|
| 74 |
+
config.input = dict()
|
| 75 |
+
config.input.data = dict(
|
| 76 |
+
name='imagenet21k',
|
| 77 |
+
split='full[51200:]',
|
| 78 |
+
)
|
| 79 |
+
config.input.batch_size = 4096
|
| 80 |
+
config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
|
| 81 |
+
|
| 82 |
+
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
|
| 83 |
+
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
|
| 84 |
+
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
|
| 85 |
+
config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k
|
| 86 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)'
|
| 87 |
+
|
| 88 |
+
# To continue using the near-defunct randaug op.
|
| 89 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
| 90 |
+
|
| 91 |
+
# Aggressive pre-fetching because our models here are small, so we not only
|
| 92 |
+
# can afford it, but we also need it for the smallest models to not be
|
| 93 |
+
# bottle-necked by the input pipeline. Play around with it for -L models tho.
|
| 94 |
+
config.input.prefetch = 8
|
| 95 |
+
config.prefetch_to_device = 4
|
| 96 |
+
|
| 97 |
+
config.log_training_steps = 50
|
| 98 |
+
config.ckpt_steps = 1000
|
| 99 |
+
|
| 100 |
+
# Model section
|
| 101 |
+
config.model_name = 'vit'
|
| 102 |
+
config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn')
|
| 103 |
+
|
| 104 |
+
# Optimizer section
|
| 105 |
+
config.optax_name = 'scale_by_adam'
|
| 106 |
+
config.optax = dict(mu_dtype='bfloat16')
|
| 107 |
+
config.grad_clip_norm = 1.0
|
| 108 |
+
|
| 109 |
+
config.lr = 0.001
|
| 110 |
+
config.wd = 0.0001
|
| 111 |
+
config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
|
| 112 |
+
|
| 113 |
+
config.mixup = MIXUP_DEF[aug_setting]
|
| 114 |
+
|
| 115 |
+
# Evaluations on i21k itself.
|
| 116 |
+
def eval_i21k(split):
|
| 117 |
+
return dict(
|
| 118 |
+
type='classification',
|
| 119 |
+
data={**config.input.data, 'split': split},
|
| 120 |
+
pp_fn=pp_eval + pp_common_i21k,
|
| 121 |
+
loss_name=config.loss,
|
| 122 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
| 123 |
+
)
|
| 124 |
+
config.evals = {}
|
| 125 |
+
config.evals.test = eval_i21k('full[:25_600]')
|
| 126 |
+
config.evals.val = eval_i21k('full[25_600:51_200]')
|
| 127 |
+
config.evals.train = eval_i21k('full[51_200:76_800]')
|
| 128 |
+
|
| 129 |
+
# Few-shot evaluators
|
| 130 |
+
config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
|
| 131 |
+
config.evals.fewshot.log_steps = 25_000
|
| 132 |
+
|
| 133 |
+
# Make a few things much smaller for quick local debugging testruns.
|
| 134 |
+
if arg.runlocal:
|
| 135 |
+
config.input.shuffle_buffer_size = 10
|
| 136 |
+
config.input.batch_size = 8
|
| 137 |
+
config.evals.test.data.split = 'full[:16]'
|
| 138 |
+
config.evals.train.data.split = 'full[:16]'
|
| 139 |
+
config.evals.val.data.split = 'full[:16]'
|
| 140 |
+
config.evals.i1k_val.data.split = 'validation[:16]'
|
| 141 |
+
config.evals.i1k_v2.data.split = 'test[:16]'
|
| 142 |
+
config.evals.i1k_a.data.split = 'test[:16]'
|
| 143 |
+
config.evals.i1k_r.data.split = 'test[:16]'
|
| 144 |
+
|
| 145 |
+
return config
|
big_vision/configs/vit_s16_i1k.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580.
|
| 17 |
+
|
| 18 |
+
This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%,
|
| 19 |
+
see the tech report for more details.
|
| 20 |
+
|
| 21 |
+
Command to run:
|
| 22 |
+
|
| 23 |
+
big_vision.train \
|
| 24 |
+
--config big_vision/configs/vit_s16_i1k.py \
|
| 25 |
+
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
|
| 26 |
+
|
| 27 |
+
To run for 300ep, add `--config.total_epochs 300` to the command.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import ml_collections as mlc
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_config():
|
| 34 |
+
"""Config for training."""
|
| 35 |
+
config = mlc.ConfigDict()
|
| 36 |
+
|
| 37 |
+
config.seed = 0
|
| 38 |
+
config.total_epochs = 90
|
| 39 |
+
config.num_classes = 1000
|
| 40 |
+
config.loss = 'softmax_xent'
|
| 41 |
+
|
| 42 |
+
config.input = {}
|
| 43 |
+
config.input.data = dict(
|
| 44 |
+
name='imagenet2012',
|
| 45 |
+
split='train[:99%]',
|
| 46 |
+
)
|
| 47 |
+
config.input.batch_size = 1024
|
| 48 |
+
config.input.cache_raw = True # Needs up to 120GB of RAM!
|
| 49 |
+
config.input.shuffle_buffer_size = 250_000
|
| 50 |
+
|
| 51 |
+
pp_common = (
|
| 52 |
+
'|value_range(-1, 1)'
|
| 53 |
+
'|onehot(1000, key="{lbl}", key_result="labels")'
|
| 54 |
+
'|keep("image", "labels")'
|
| 55 |
+
)
|
| 56 |
+
config.input.pp = (
|
| 57 |
+
'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' +
|
| 58 |
+
pp_common.format(lbl='label')
|
| 59 |
+
)
|
| 60 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
|
| 61 |
+
|
| 62 |
+
# To continue using the near-defunct randaug op.
|
| 63 |
+
config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
|
| 64 |
+
|
| 65 |
+
config.log_training_steps = 50
|
| 66 |
+
config.ckpt_steps = 1000
|
| 67 |
+
|
| 68 |
+
# Model section
|
| 69 |
+
config.model_name = 'vit'
|
| 70 |
+
config.model = dict(
|
| 71 |
+
variant='S/16',
|
| 72 |
+
rep_size=True,
|
| 73 |
+
pool_type='gap',
|
| 74 |
+
posemb='sincos2d',
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Optimizer section
|
| 78 |
+
config.grad_clip_norm = 1.0
|
| 79 |
+
config.optax_name = 'scale_by_adam'
|
| 80 |
+
config.optax = dict(mu_dtype='bfloat16')
|
| 81 |
+
|
| 82 |
+
config.lr = 0.001
|
| 83 |
+
config.wd = 0.0001
|
| 84 |
+
config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
|
| 85 |
+
|
| 86 |
+
config.mixup = dict(p=0.2, fold_in=None)
|
| 87 |
+
|
| 88 |
+
# Eval section
|
| 89 |
+
def get_eval(split, dataset='imagenet2012'):
|
| 90 |
+
return dict(
|
| 91 |
+
type='classification',
|
| 92 |
+
data=dict(name=dataset, split=split),
|
| 93 |
+
pp_fn=pp_eval.format(lbl='label'),
|
| 94 |
+
loss_name=config.loss,
|
| 95 |
+
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
|
| 96 |
+
)
|
| 97 |
+
config.evals = {}
|
| 98 |
+
config.evals.train = get_eval('train[:2%]')
|
| 99 |
+
config.evals.minival = get_eval('train[99%:]')
|
| 100 |
+
config.evals.val = get_eval('validation')
|
| 101 |
+
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
|
| 102 |
+
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
|
| 103 |
+
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
|
| 104 |
+
|
| 105 |
+
return config
|
big_vision/datasets/ai2d/ai2d.py
ADDED
|
File without changes
|
big_vision/datasets/aokvqa/aokvqa.py
ADDED
|
File without changes
|
big_vision/datasets/chartqa/chartqa.py
ADDED
|
File without changes
|
big_vision/datasets/coco35l/coco35l.py
ADDED
|
File without changes
|
big_vision/datasets/core.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Core data functions, dispatch calls to the requested dataset."""
|
| 16 |
+
import importlib
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Note: intentionally not using ABC to avoid forcing implementation of every
|
| 20 |
+
# method, since one can imagine train-only datasets for example.
|
| 21 |
+
class DataSource:
|
| 22 |
+
"""The API that any data source should implement."""
|
| 23 |
+
|
| 24 |
+
def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
|
| 25 |
+
"""Creates this data object as a tf.data.Dataset.
|
| 26 |
+
|
| 27 |
+
This will be called separately in each process, and it is up to the dataset
|
| 28 |
+
implementation to shard it accordingly if desired!
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
ordered: if True, the dataset should use deterministic ordering, if False
|
| 32 |
+
it may have undefined ordering. Think of True == val, False == train.
|
| 33 |
+
process_split: if False then every process receives the entire dataset
|
| 34 |
+
(e.g. for evaluators running in a single process).
|
| 35 |
+
allow_cache: whether to allow caching the opened data or not.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
A tf.data.Dataset object.
|
| 39 |
+
|
| 40 |
+
Raises:
|
| 41 |
+
RuntimeError: if not implemented by the dataset, but called.
|
| 42 |
+
"""
|
| 43 |
+
raise RuntimeError("not implemented for {self.__class__.__name__}")
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def total_examples(self):
|
| 47 |
+
"""Returns number of examples in the dataset, regardless of sharding."""
|
| 48 |
+
raise RuntimeError("not implemented for {self.__class__.__name__}")
|
| 49 |
+
|
| 50 |
+
def num_examples_per_process(self):
|
| 51 |
+
"""Returns a list of the numer of examples for each process.
|
| 52 |
+
|
| 53 |
+
This is only needed for datasets that should go through make_for_inference.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Returns a list of the numer of examples for each process.
|
| 57 |
+
|
| 58 |
+
Ideally, this would always be `[total() / nprocess] * nprocess`, but in
|
| 59 |
+
reality we can almost never perfectly shard a dataset across arbitrary
|
| 60 |
+
number of processes.
|
| 61 |
+
|
| 62 |
+
One alternative option that can work in some cases is to not even shard
|
| 63 |
+
the dataset and thus return `[num_examples()] * nprocess.
|
| 64 |
+
|
| 65 |
+
Raises:
|
| 66 |
+
RuntimeError: if not implemented by the dataset, but called.
|
| 67 |
+
"""
|
| 68 |
+
raise RuntimeError("not implemented for {self.__class__.__name__}")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get(name, **kw):
|
| 72 |
+
if name.startswith("bv:"):
|
| 73 |
+
mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
|
| 74 |
+
return mod.DataSource(**kw)
|
| 75 |
+
else:
|
| 76 |
+
mod = importlib.import_module("big_vision.datasets.tfds")
|
| 77 |
+
return mod.DataSource(name, **kw)
|
big_vision/datasets/countbenchqa/countbenchqa.py
ADDED
|
File without changes
|
big_vision/datasets/docvqa/docvqa.py
ADDED
|
File without changes
|
big_vision/datasets/gqa/gqa.py
ADDED
|
File without changes
|
big_vision/datasets/imagenet/class_names.py
ADDED
|
File without changes
|
big_vision/datasets/infovqa/infovqa.py
ADDED
|
File without changes
|
big_vision/datasets/jsonl.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Simple data input from .jsonl files."""
|
| 16 |
+
|
| 17 |
+
import hashlib
|
| 18 |
+
import json
|
| 19 |
+
from multiprocessing.pool import ThreadPool
|
| 20 |
+
import os
|
| 21 |
+
import tempfile
|
| 22 |
+
import urllib.request
|
| 23 |
+
|
| 24 |
+
from absl import logging
|
| 25 |
+
import big_vision.datasets.core as ds_core
|
| 26 |
+
import jax
|
| 27 |
+
import numpy as np
|
| 28 |
+
import overrides
|
| 29 |
+
import tensorflow as tf
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def cached_download(url, dest=None, verbose=True):
|
| 33 |
+
"""Download `url` to local file and return path to that, but with caching."""
|
| 34 |
+
# NOTE: there is a small chance of saving corrupted data if the process is
|
| 35 |
+
# interrupted in the middle of writing the file. Then, reading in the input
|
| 36 |
+
# pipeline will fail, and the fix is to nuke the temp folder.
|
| 37 |
+
|
| 38 |
+
# Compute a temp name based on the URL, so we can check if we already
|
| 39 |
+
# downloaded it before.
|
| 40 |
+
dest = dest or os.path.join(tempfile.gettempdir(), "bv")
|
| 41 |
+
os.makedirs(dest, exist_ok=True)
|
| 42 |
+
dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())
|
| 43 |
+
|
| 44 |
+
# NOTE: we should use last-modified header to know whether to re-download.
|
| 45 |
+
if os.path.isfile(dest):
|
| 46 |
+
return dest
|
| 47 |
+
|
| 48 |
+
if verbose:
|
| 49 |
+
print(f"\rRetrieving {url} into {dest}", end="", flush=True)
|
| 50 |
+
|
| 51 |
+
with urllib.request.urlopen(url) as f:
|
| 52 |
+
data = f.read()
|
| 53 |
+
with open(dest, "wb+") as f:
|
| 54 |
+
f.write(data)
|
| 55 |
+
return dest
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DataSource(ds_core.DataSource):
|
| 59 |
+
""".jsonl DataSource."""
|
| 60 |
+
|
| 61 |
+
def __init__(self, fname, *, fopen_keys=(), download_keys=(),
|
| 62 |
+
start=0, stop=float("inf")):
|
| 63 |
+
"""Create data-source that's jsonl + data files (eg images).
|
| 64 |
+
|
| 65 |
+
This correctly supports multi-host in that each host only reads a subset of
|
| 66 |
+
the dataset automatically. However, currently, all hosts download all items
|
| 67 |
+
if `download_keys` is specified. TODO: b/lbeyer - This can be improved.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
fname: str, the path to the jsonl file that holds the dataset.
|
| 71 |
+
fopen_keys: collection of str or dict, the keys in the dataset whose
|
| 72 |
+
string value actually is a file-path that should be opened and read,
|
| 73 |
+
and its content is what goes into the batch (eg image filenames
|
| 74 |
+
commonly ["image"]).
|
| 75 |
+
If a dict, the values are folders prefixed to the filenames.
|
| 76 |
+
Supports gs:// for reading from buckets.
|
| 77 |
+
download_keys: collection of str, the keys in the dataset whose string
|
| 78 |
+
value actually is a URL from which the file should be downloaded first.
|
| 79 |
+
files are downloaded to a persistent tmp folder using the URL hash as
|
| 80 |
+
filename. If the file already exists, the download is skipped.
|
| 81 |
+
Must be a subset of `fopen_keys`.
|
| 82 |
+
start: int, index of the first row to use; use for slicing the data.
|
| 83 |
+
stop: int or inf, index of the row after the last one to use.
|
| 84 |
+
|
| 85 |
+
Note:
|
| 86 |
+
This simple data input does not allow for nested/hierarchical values,
|
| 87 |
+
or in any way more complicated values like vectors. Use TFDS for that.
|
| 88 |
+
|
| 89 |
+
The way start/stop arguments are used is as in list slicing[start:stop].
|
| 90 |
+
"""
|
| 91 |
+
self.examples = []
|
| 92 |
+
|
| 93 |
+
with tf.io.gfile.GFile(fname) as f:
|
| 94 |
+
for i, line in enumerate(f):
|
| 95 |
+
if (start or 0) <= i < (stop or float("inf")):
|
| 96 |
+
try:
|
| 97 |
+
self.examples.append(json.loads(line))
|
| 98 |
+
except json.decoder.JSONDecodeError as e:
|
| 99 |
+
raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e
|
| 100 |
+
|
| 101 |
+
if download_keys:
|
| 102 |
+
for k in download_keys:
|
| 103 |
+
assert k in fopen_keys, (
|
| 104 |
+
f"{k} in download_keys but missing from fopen_keys {fopen_keys}")
|
| 105 |
+
|
| 106 |
+
# TODO: b/lbeyer - use info from trainer instead, move that to utils.
|
| 107 |
+
logging.info( # pylint: disable=logging-fstring-interpolation
|
| 108 |
+
f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
|
| 109 |
+
f"for dataset {fname} ({len(self.examples)} examples) ...")
|
| 110 |
+
|
| 111 |
+
def _dl_one(ex):
|
| 112 |
+
for k in download_keys:
|
| 113 |
+
ex[k] = cached_download(ex[k])
|
| 114 |
+
|
| 115 |
+
ThreadPool(100).map(_dl_one, self.examples)
|
| 116 |
+
print("Done")
|
| 117 |
+
logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")
|
| 118 |
+
|
| 119 |
+
# Normalize.
|
| 120 |
+
if isinstance(fopen_keys, (list, tuple)):
|
| 121 |
+
self.fopen_keys = {k: "" for k in fopen_keys}
|
| 122 |
+
else:
|
| 123 |
+
self.fopen_keys = fopen_keys or {}
|
| 124 |
+
|
| 125 |
+
# We need to apply fopen path prefix here already, because doing so while
|
| 126 |
+
# actually reading the files in TF, things are symbolic :(
|
| 127 |
+
for ex in self.examples:
|
| 128 |
+
for k, dirname in self.fopen_keys.items():
|
| 129 |
+
ex[k] = os.path.join(dirname, ex[k])
|
| 130 |
+
|
| 131 |
+
def _indices(self, *, process_split=True, process_index=None):
|
| 132 |
+
indices = np.arange(len(self.examples))
|
| 133 |
+
|
| 134 |
+
if not process_split:
|
| 135 |
+
return list(indices)
|
| 136 |
+
|
| 137 |
+
pid = jax.process_index() if process_index is None else process_index
|
| 138 |
+
return list(np.array_split(indices, jax.process_count())[pid])
|
| 139 |
+
|
| 140 |
+
@overrides.overrides
|
| 141 |
+
def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
|
| 142 |
+
del allow_cache # We don't cache anything anyways.
|
| 143 |
+
assert not process_split or len(self.examples) >= jax.process_count(), (
|
| 144 |
+
"Process splitting the data with fewer examples than processes!?")
|
| 145 |
+
|
| 146 |
+
my_idxs = self._indices(process_split=process_split)
|
| 147 |
+
if not ordered:
|
| 148 |
+
np.random.shuffle(my_idxs)
|
| 149 |
+
|
| 150 |
+
dataset = tf.data.Dataset.from_generator(
|
| 151 |
+
generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
|
| 152 |
+
output_signature={
|
| 153 |
+
"id": _guess_signature("0"),
|
| 154 |
+
**{k: _guess_signature(v) for k, v in self.examples[0].items()},
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
def _read_files(example):
|
| 158 |
+
for k in self.fopen_keys:
|
| 159 |
+
example[k] = tf.io.read_file(example[k])
|
| 160 |
+
return example
|
| 161 |
+
dataset = dataset.map(_read_files)
|
| 162 |
+
|
| 163 |
+
return dataset
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
@overrides.overrides
|
| 167 |
+
def total_examples(self):
|
| 168 |
+
return len(self.examples)
|
| 169 |
+
|
| 170 |
+
@overrides.overrides
|
| 171 |
+
def num_examples_per_process(self):
|
| 172 |
+
return [len(self._indices(process_index=pid))
|
| 173 |
+
for pid in range(jax.process_count())]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _guess_signature(value):
|
| 177 |
+
return tf.TensorSpec.from_tensor(tf.constant(value))
|
big_vision/datasets/nocaps/nocaps.py
ADDED
|
File without changes
|
big_vision/datasets/okvqa/okvqa.py
ADDED
|
File without changes
|
big_vision/datasets/pope/pope.py
ADDED
|
File without changes
|
big_vision/datasets/refcoco/refcoco.py
ADDED
|
File without changes
|
big_vision/datasets/rsvqa_hr/rsvqa_hr.py
ADDED
|
File without changes
|
big_vision/datasets/rsvqa_lr/rsvqa_lr.py
ADDED
|
File without changes
|
big_vision/datasets/scicap/scicap.py
ADDED
|
File without changes
|
big_vision/datasets/science_qa/science_qa.py
ADDED
|
File without changes
|
big_vision/datasets/screen2words/screen2words.py
ADDED
|
File without changes
|
big_vision/datasets/sequence_packing.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Packed Sequence Op."""
|
| 16 |
+
|
| 17 |
+
# Forked from
|
| 18 |
+
# https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from typing import Dict, Optional, List, Union
|
| 22 |
+
|
| 23 |
+
from flax import traverse_util
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
|
| 26 |
+
AUTOTUNE = tf.data.experimental.AUTOTUNE
|
| 27 |
+
FLATTEN_SEPARATOR = "<|sep|>"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def pack_dataset(
|
| 31 |
+
dataset: tf.data.Dataset,
|
| 32 |
+
batch_size: int | None,
|
| 33 |
+
key2length: Union[int, Dict[str, int]],
|
| 34 |
+
keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset:
|
| 35 |
+
"""Creates a 'packed' version of a dataset on-the-fly.
|
| 36 |
+
|
| 37 |
+
Wrap `tensorflow.grain` ops.
|
| 38 |
+
|
| 39 |
+
This is meant to replace the irritation of having to create a separate
|
| 40 |
+
"packed" version of a dataset to train efficiently on TPU.
|
| 41 |
+
Each example in the output dataset represents several examples in the
|
| 42 |
+
input dataset.
|
| 43 |
+
|
| 44 |
+
For each key in the input dataset, two additional keys are created:
|
| 45 |
+
<key>_segment_ids: an int32 tensor identifying the parts
|
| 46 |
+
representing the original example.
|
| 47 |
+
<key>_positions: an int32 tensor identifying the position within the original
|
| 48 |
+
example.
|
| 49 |
+
|
| 50 |
+
Example:
|
| 51 |
+
Two input examples get combined to form an output example.
|
| 52 |
+
The input examples are:
|
| 53 |
+
{"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
|
| 54 |
+
{"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
|
| 55 |
+
The output example is:
|
| 56 |
+
{
|
| 57 |
+
"inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
|
| 58 |
+
"inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
|
| 59 |
+
"inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
|
| 60 |
+
"targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
|
| 61 |
+
"targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
|
| 62 |
+
"targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
|
| 63 |
+
}
|
| 64 |
+
0 represents padding in both the inputs and the outputs.
|
| 65 |
+
Sequences in the incoming examples are truncated to length "length", and the
|
| 66 |
+
sequences in the output examples all have fixed (padded) length "length".
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
dataset: A `tf.data.Dataset`.
|
| 70 |
+
batch_size: Batch size of the packed dataset.
|
| 71 |
+
key2length: An integer, or a dict from feature-key to integer.
|
| 72 |
+
keys: A list of strings (e.g. ["inputs", "targets"]).
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
A `tf.data.Dataset`.
|
| 76 |
+
"""
|
| 77 |
+
raise ValueError("Not implemented in OSS yet.")
|
big_vision/datasets/stvqa/stvqa.py
ADDED
|
File without changes
|
big_vision/datasets/tallyqa/tallyqa.py
ADDED
|
File without changes
|
big_vision/datasets/textcaps/textcaps.py
ADDED
|
File without changes
|