fcxfcx committed on
Commit 742a3d1 · verified · 1 Parent(s): 6f7b1aa

Upload 549 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +5 -0
  2. CONTRIBUTING.md +32 -0
  3. IOU_test.py +21 -0
  4. LICENSE +202 -0
  5. README.md +217 -0
  6. __pycache__/owlv2_helper.cpython-310.pyc +0 -0
  7. __pycache__/owlv2_helper_functions.cpython-310.pyc +0 -0
  8. auto_bbox.py +266 -0
  9. big_vision/.gitignore +1 -0
  10. big_vision/CONTRIBUTING.md +26 -0
  11. big_vision/LICENSE +201 -0
  12. big_vision/README.md +499 -0
  13. big_vision/__init__.py +0 -0
  14. big_vision/__pycache__/__init__.cpython-310.pyc +0 -0
  15. big_vision/__pycache__/utils.cpython-310.pyc +0 -0
  16. big_vision/configs/__init__.py +0 -0
  17. big_vision/configs/bit_i1k.py +102 -0
  18. big_vision/configs/bit_i21k.py +85 -0
  19. big_vision/configs/common.py +188 -0
  20. big_vision/configs/common_fewshot.py +60 -0
  21. big_vision/configs/load_and_eval.py +143 -0
  22. big_vision/configs/mlp_mixer_i1k.py +120 -0
  23. big_vision/configs/transfer.py +186 -0
  24. big_vision/configs/vit_i1k.py +177 -0
  25. big_vision/configs/vit_i21k.py +145 -0
  26. big_vision/configs/vit_s16_i1k.py +105 -0
  27. big_vision/datasets/ai2d/ai2d.py +0 -0
  28. big_vision/datasets/aokvqa/aokvqa.py +0 -0
  29. big_vision/datasets/chartqa/chartqa.py +0 -0
  30. big_vision/datasets/coco35l/coco35l.py +0 -0
  31. big_vision/datasets/core.py +77 -0
  32. big_vision/datasets/countbenchqa/countbenchqa.py +0 -0
  33. big_vision/datasets/docvqa/docvqa.py +0 -0
  34. big_vision/datasets/gqa/gqa.py +0 -0
  35. big_vision/datasets/imagenet/class_names.py +0 -0
  36. big_vision/datasets/infovqa/infovqa.py +0 -0
  37. big_vision/datasets/jsonl.py +177 -0
  38. big_vision/datasets/nocaps/nocaps.py +0 -0
  39. big_vision/datasets/okvqa/okvqa.py +0 -0
  40. big_vision/datasets/pope/pope.py +0 -0
  41. big_vision/datasets/refcoco/refcoco.py +0 -0
  42. big_vision/datasets/rsvqa_hr/rsvqa_hr.py +0 -0
  43. big_vision/datasets/rsvqa_lr/rsvqa_lr.py +0 -0
  44. big_vision/datasets/scicap/scicap.py +0 -0
  45. big_vision/datasets/science_qa/science_qa.py +0 -0
  46. big_vision/datasets/screen2words/screen2words.py +0 -0
  47. big_vision/datasets/sequence_packing.py +77 -0
  48. big_vision/datasets/stvqa/stvqa.py +0 -0
  49. big_vision/datasets/tallyqa/tallyqa.py +0 -0
  50. big_vision/datasets/textcaps/textcaps.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ckpts/clip_vit_l14_with_masks_6c17944 filter=lfs diff=lfs merge=lfs -text
37
+ ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b filter=lfs diff=lfs merge=lfs -text
38
+ ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c filter=lfs diff=lfs merge=lfs -text
39
+ images/scenic_design.jpg filter=lfs diff=lfs merge=lfs -text
40
+ images/scenic_logo.jpg filter=lfs diff=lfs merge=lfs -text
CONTRIBUTING.md ADDED
@@ -0,0 +1,32 @@
1
+ # How to Contribute
2
+
3
+ Scenic is a platform used for developing new methods and ideas by Google
4
+ researchers, mostly around attention-based models for computer vision or
5
+ multi-modal applications. We encourage forking the repository and continued
6
+ development. We welcome suggestions and contributions to improving Scenic.
7
+ There are a few small guidelines you need to follow.
8
+
9
+ ## Contributor License Agreement
10
+
11
+ Contributions to this project must be accompanied by a Contributor License
12
+ Agreement (CLA). You (or your employer) retain the copyright to your
13
+ contribution; this simply gives us permission to use and redistribute your
14
+ contributions as part of the project. Head over to
15
+ <https://cla.developers.google.com/> to see your current agreements on file or
16
+ to sign a new one.
17
+
18
+ You generally only need to submit a CLA once, so if you've already submitted one
19
+ (even if it was for a different project), you probably don't need to do it
20
+ again.
21
+
22
+ ## Code Reviews
23
+
24
+ All submissions, including submissions by project members, require review. We
25
+ use GitHub pull requests for this purpose. Consult
26
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
27
+ information on using pull requests.
28
+
29
+ ## Community Guidelines
30
+
31
+ This project follows
32
+ [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
IOU_test.py ADDED
@@ -0,0 +1,21 @@
1
+ from owlv2_helper_functions import get_iou, boxes_filter
2
+
3
+ boxes = [
4
+ (128.56, 4.57, 732.52, 476.05),
5
+ (569.65, 185.71, 740.31, 244.76),
6
+ (569.65, 185.71, 740.31, 244.76),
7
+ (569.65, 185.71, 740.31, 244.76),
8
+ (101.99, 99.00, 720.12, 88.63),
9
+ ]
10
+
11
+ scores = [1.0, 0.99, 0.89, 1.0, 0.99]
12
+
13
+ instances = ['cat', 'dog', 'dog', 'tiger', 'cat']
14
+
15
+
16
+
17
+ pred_bboxes, pred_scores, instances = boxes_filter(boxes, scores, instances)
18
+
19
+ print(pred_bboxes)
20
+ print(pred_scores)
21
+ print(instances)
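
For context, here is a minimal sketch of what `get_iou` and `boxes_filter` might compute, assuming `(x1, y1, x2, y2)` pixel boxes and assuming `boxes_filter` keeps only the highest-scoring box among near-duplicate detections; the real implementations live in `owlv2_helper_functions.py`, which is part of this upload but not shown in this 50-file view.

```python
from typing import List, Tuple

def get_iou(box_a: Tuple[float, ...], box_b: Tuple[float, ...]) -> float:
    """Standard intersection-over-union for (x1, y1, x2, y2) boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # Intersection rectangle (empty if the boxes do not overlap).
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
    area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def boxes_filter(boxes, scores, instances, iou_thresh: float = 0.9):
    """Keep the highest-scoring box among near-identical detections."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    kept: List[int] = []
    for i in order:
        if all(get_iou(boxes[i], boxes[j]) < iou_thresh for j in kept):
            kept.append(i)
    kept.sort()
    return ([boxes[i] for i in kept],
            [scores[i] for i in kept],
            [instances[i] for i in kept])
```

Under this sketch, the three identical boxes in the test above would collapse to a single detection, keeping the top-scoring entry.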
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,217 @@
1
+ # Scenic
2
+ <div style="text-align: left">
3
+ <img align="right" src="https://raw.githubusercontent.com/google-research/scenic/main/images/scenic_logo.png" width="200" alt="scenic logo"></img>
4
+ </div>
5
+
6
+ *Scenic* is a codebase with a focus on research around attention-based models
7
+ for computer vision. Scenic has been successfully used to develop
8
+ classification, segmentation, and detection models for multiple modalities
9
+ including images, video, audio, and multimodal combinations of them.
10
+
11
+ More precisely, *Scenic* is a (i) set of shared light-weight libraries solving
12
+ commonly encountered tasks when training large-scale (i.e. multi-device,
13
+ multi-host) vision models; and (ii) several *projects* containing fully
14
+ fleshed out problem-specific training and evaluation loops using these
15
+ libraries.
16
+
17
+ Scenic is developed in [JAX](https://github.com/jax-ml/jax) and uses
18
+ [Flax](https://github.com/google/flax).
19
+
20
+ ### Contents
21
+ * [What we offer](#what-we-offer)
22
+ * [SOTA models and baselines in Scenic](#sota-models-and-baselines-in-scenic)
23
+ * [Philosophy](#philosophy)
24
+ * [Getting started](#getting-started)
25
+ * [Scenic component design](#scenic-component-design)
26
+ * [Citing Scenic](#citing-scenic)
27
+
28
+ ## What we offer
29
+ Among others, *Scenic* provides:
30
+
31
+ * Boilerplate code for launching experiments, summary writing, logging,
32
+ profiling, etc;
33
+ * Optimized training and evaluation loops, losses, metrics, bi-partite matchers,
34
+ etc;
35
+ * Input-pipelines for popular vision datasets;
36
+ * [Baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models),
37
+ including strong non-attentional baselines.
38
+
39
+
40
+ ## SOTA models and baselines in *Scenic*
41
+ There are several SOTA models and baselines in Scenic, which were either developed
42
+ using Scenic or have been reimplemented in it:
43
+
44
+ Projects that were developed in Scenic or used it for their experiments:
45
+
46
+ * [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
47
+ * [OmniNet: Omnidirectional Representations from Transformers](https://arxiv.org/abs/2103.01075)
48
+ * [Attention Bottlenecks for Multimodal Fusion](https://arxiv.org/abs/2107.00135)
49
+ * [TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?](https://arxiv.org/abs/2106.11297)
50
+ * [Exploring the Limits of Large Scale Pre-training](https://arxiv.org/abs/2110.02095)
51
+ * [The Efficiency Misnomer](https://arxiv.org/abs/2110.12894)
52
+ * [Discrete Representations Strengthen Vision Transformer Robustness](https://arxiv.org/abs/2111.10493)
53
+ * [Pyramid Adversarial Training Improves ViT Performance](https://arxiv.org/abs/2111.15121)
54
+ * [VUT: Versatile UI Transformer for Multi-Modal Multi-Task User Interface Modeling](https://arxiv.org/abs/2112.05692)
55
+ * [CLAY: Learning to Denoise Raw Mobile UI Layouts for Improving Datasets at Scale](https://arxiv.org/abs/2201.04100)
56
+ * [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455)
57
+ * [Multiview Transformers for Video Recognition](https://arxiv.org/abs/2201.04288)
58
+ * [PolyViT: Co-training Vision Transformers on Images, Videos and Audio](https://arxiv.org/abs/2111.12993)
59
+ * [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230)
60
+ * [Learning with Neighbor Consistency for Noisy Labels](https://arxiv.org/abs/2202.02200)
61
+ * [Token Turing Machines](https://arxiv.org/pdf/2211.09119.pdf)
62
+ * [Vid2Seq: Large-Scale Pretraining of a Visual Language Model for Dense Video Captioning](https://arxiv.org/pdf/2302.14115.pdf)
63
+ * [AVATAR: Unconstrained Audiovisual Speech Recognition](https://arxiv.org/abs/2206.07684)
64
+ * [Adaptive Computation with Elastic Input Sequence](https://arxiv.org/abs/2301.13195)
65
+ * [Location-Aware Self-Supervised Transformers for Semantic Segmentation](https://arxiv.org/abs/2212.02400)
66
+ * [How can objects help action recognition?](https://openaccess.thecvf.com/content/CVPR2023/html/Zhou_How_Can_Objects_Help_Action_Recognition_CVPR_2023_paper.html)
67
+ * [Verbs in Action: Improving verb understanding in video-language models](https://arxiv.org/abs/2304.06708)
68
+ * [Unified Visual Relationship Detection with Vision and Language Models](https://arxiv.org/abs/2303.08998)
69
+ * [UnLoc: A Unified Framework for Video Localization Tasks](https://arxiv.org/abs/2308.11062)
70
+ * [REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory](https://arxiv.org/abs/2212.05221)
71
+ * [Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922)
72
+ * [MatFormer: Nested Transformer for Elastic Inference](https://arxiv.org/abs/2310.07707)
73
+ * [Pixel Aligned Language Models](https://arxiv.org/abs/2312.09237)
74
+ * [A Generative Approach for Wikipedia-Scale Visual Entity Recognition](https://arxiv.org/abs/2403.02041)
75
+ * [Streaming Dense Video Captioning](https://arxiv.org/abs/2404.01297)
76
+ * [Dense Video Object Captioning from Disjoint Supervision](https://arxiv.org/abs/2306.11729)
77
+
78
+ More information can be found in [projects](https://github.com/google-research/scenic/tree/main/scenic/projects#list-of-projects-hosted-in-scenic).
79
+
80
+ Baselines that were reproduced in Scenic:
81
+
82
+ * [(ViT) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
83
+ * [(DETR) End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872)
84
+ * [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159)
85
+ * [(CLIP) Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
86
+ * [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
87
+ * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
88
+ * [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
89
+ * [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370)
90
+ * [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
91
+ * [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
92
+ * [PCT: Point Cloud Transformer](https://arxiv.org/abs/2012.09688)
93
+ * [Universal Transformers](https://arxiv.org/abs/1807.03819)
94
+ * [PonderNet](https://arxiv.org/abs/2107.05407)
95
+ * [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
96
+ * [Rethinking Attention with Performers](https://arxiv.org/abs/2009.14794)
97
+ * [(CenterNet) Objects as Points](https://arxiv.org/abs/1904.07850)
98
+ * [(SAM) Segment Anything](https://arxiv.org/abs/2304.02643)
99
+
100
+
101
+ More information can be found in [baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models).
102
+
103
+ <a name="philosophy"></a>
104
+ ## Philosophy
105
+ *Scenic* aims to facilitate rapid prototyping of large-scale vision models. To
106
+ keep the code simple to understand and extend, we prefer *forking and
107
+ copy-pasting over adding complexity or increasing abstraction*. Only when
108
+ functionality proves to be widely useful across many models and tasks may it be
109
+ upstreamed to Scenic's shared libraries.
110
+
111
+
112
+ <a name="getting_start"></a>
113
+ ## Getting started
114
+ * See `projects/baselines/README.md` for a walk-through of baseline models and
115
+ instructions on how to run the code.
116
+ * If you would like to contribute to *Scenic*, please check out the
117
+ [Philosophy](#philosophy), [Code structure](#code_structure) and
118
+ [Contributing](CONTRIBUTING.md) sections.
119
+ Should your contribution be a part of the shared libraries, please send us a
120
+ pull request!
121
+
122
+
123
+ ### Quickstart
124
+ You will need Python 3.9 or later. Download the code from GitHub
125
+
126
+ ```shell
127
+ $ git clone https://github.com/google-research/scenic.git
128
+ $ cd scenic
129
+ $ pip install .
130
+ ```
131
+
132
+ and run training for ViT on ImageNet:
133
+
134
+ ```shell
135
+ $ python scenic/main.py -- \
136
+ --config=scenic/projects/baselines/configs/imagenet/imagenet_vit_config.py \
137
+ --workdir=./
138
+ ```
139
+
140
+ Note that for specific projects and baselines, you might need to install extra
141
+ packages that are mentioned in their `README.md` or `requirements.txt` files.
142
+
143
+ [Here](https://colab.research.google.com/github/google-research/scenic/blob/main/scenic/common_lib/colabs/scenic_playground.ipynb)
144
+ is also a minimal colab to train a simple feed-forward model using Scenic.
145
+
146
+ <a name="code_structure"></a>
147
+ ## Scenic component design
148
+ Scenic is designed to offer different levels of abstraction: it supports
149
+ hosting projects that only require changing hyper-parameters by defining config
150
+ files, as well as those that need to customize the input pipeline, model
151
+ architecture, losses and metrics, and the training loop. To make this happen,
152
+ the code in Scenic is organized as either _project-level_ code,
153
+ which refers to customized code for specific projects or baselines, or
154
+ _library-level_ code, which refers to common functionalities and general
155
+ patterns that are adopted by the majority of projects. The project-level
156
+ code lives in the `projects` directory.
157
+
158
+ <div align="center">
159
+ <img src="https://raw.githubusercontent.com/google-research/scenic/main/images/scenic_design.jpg" width="900" alt="scenic design"></img>
160
+ </div>
161
+
162
+ ### Library-level code
163
+ The goal is to keep the library-level code minimal and well-tested and to avoid
164
+ introducing extra abstractions to support minor use-cases. Shared libraries
165
+ provided by *Scenic* are split into:
166
+
167
+ * `dataset_lib`: Implements IO pipelines for loading and pre-processing data
168
+ for common Computer Vision tasks and benchmarks (see "Tasks and Datasets"
169
+ section). All pipelines are designed to be scalable and support multi-host
170
+ and multi-device setups, taking care of dividing data among multiple hosts,
171
+ incomplete batches, caching, pre-fetching, etc.
172
+ * `model_lib`: Provides
173
+ * several abstract model interfaces (e.g. `ClassificationModel` or
174
+ `SegmentationModel` in `model_lib.base_models`) with task-specific
175
+ losses and metrics;
176
+ * neural network layers in `model_lib.layers`, focusing on efficient
177
+ implementation of attention and transformer layers;
178
+ * accelerator-friendly implementations of bipartite matching
179
+ algorithms in `model_lib.matchers`.
180
+ * `train_lib`: Provides tools for constructing training loops and implements
181
+ several optimized trainers (classification trainer and segmentation trainer)
182
+ that can be forked for customization.
183
+ * `common_lib`: General utilities, like logging and debugging modules,
184
+ functionalities for processing raw data, etc.
185
+
186
+ ### Project-level code
187
+ Scenic supports the development of customized solutions for custom tasks and
188
+ data via the concept of a "project". There is no one-size-fits-all recipe for how much
189
+ code should be re-used by a project. Projects can consist of only configs and
190
+ use the common models, trainers, task/data that live in library-level code, or
191
+ they can simply fork any of the mentioned functionalities and redefine layers,
192
+ losses, metrics, logging methods, tasks, architectures, as well as training and
193
+ evaluation loops. The modularity of library-level code makes it possible for
194
+ projects to fall anywhere on the "run-as-is" to "fully customized"
195
+ spectrum.
196
+
197
+ Common baselines such as a ResNet and Vision Transformer (ViT) are implemented
198
+ in the [`projects/baselines`](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines)
199
+ project. Forking models in this directory is a good starting point for new
200
+ projects.
201
+
202
+
203
+ ## Citing Scenic
204
+ If you use Scenic, you can cite our [white paper](https://openaccess.thecvf.com/content/CVPR2022/html/Dehghani_Scenic_A_JAX_Library_for_Computer_Vision_Research_and_Beyond_CVPR_2022_paper.html).
205
+ Here is an example BibTeX entry:
206
+
207
+ ```bibtex
208
+ @InProceedings{dehghani2021scenic,
209
+ author = {Dehghani, Mostafa and Gritsenko, Alexey and Arnab, Anurag and Minderer, Matthias and Tay, Yi},
210
+ title = {Scenic: A JAX Library for Computer Vision Research and Beyond},
211
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
212
+ year = {2022},
213
+ pages = {21393-21398}
214
+ }
215
+ ```
216
+
217
+ _Disclaimer: This is not an official Google product._
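
To relate this to the "Scenic component design" section above: at the "configs-only" end of the project spectrum, a hypothetical project config could look like the sketch below. The pattern of a `get_config()` function returning an `ml_collections.ConfigDict` matches the config files included in this upload, but the field names here are purely illustrative and are not the schema of any particular Scenic trainer.

```python
# Hypothetical Scenic-style project config. Field names are illustrative only.
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()

  # Data.
  config.dataset_name = 'imagenet2012'
  config.batch_size = 1024

  # Model.
  config.model_name = 'vit_classification'
  config.model = ml_collections.ConfigDict()
  config.model.num_heads = 12
  config.model.num_layers = 12
  config.model.hidden_size = 768

  # Training.
  config.num_training_epochs = 90
  config.optimizer = 'adam'
  config.lr_configs = ml_collections.ConfigDict({'base_learning_rate': 1e-3})

  return config
```

A project that needs more than hyper-parameter changes would fork the trainer or model code rather than (or in addition to) shipping such a config.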
__pycache__/owlv2_helper.cpython-310.pyc ADDED
Binary file (4.22 kB).
 
__pycache__/owlv2_helper_functions.cpython-310.pyc ADDED
Binary file (7.51 kB).
 
auto_bbox.py ADDED
@@ -0,0 +1,266 @@
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import json
5
+ import glob
6
+ import argparse
7
+ import subprocess
8
+ from typing import List, Tuple, Dict, Any
9
+
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+
13
+
14
+ # ----------------- Args -----------------
15
+ def parse_args():
16
+ ap = argparse.ArgumentParser("OWLv2 detection on JPG folders (Top-K per image), multi-GPU.")
17
+ ap.add_argument("--input_dir", type=str, required=True, help="Root that contains subfolders of JPGs; if JPGs are directly under input_dir, it will be treated as a single set.")
18
+ ap.add_argument("--startswith", type=str, default="", help="Filter folder name prefix (or input_dir basename if no subfolders).")
19
+ ap.add_argument("--output_dir", type=str, required=True)
20
+ ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N-th image within a folder.")
21
+ ap.add_argument("--top_k", type=int, default=5)
22
+ ap.add_argument("--max_frames", type=int, default=0, help="Max processed images per folder; 0 means no limit.")
23
+ ap.add_argument("--num_workers", type=int, default=1, help="#GPUs/#workers")
24
+ ap.add_argument("--worker_idx", type=int, default=-1, help="internal; >=0 means child worker")
25
+ ap.add_argument("--shard_file", type=str, default="", help="internal; JSON with folder paths for this worker")
26
+ ap.add_argument("--scenic_root", type=str, default="/home/ubuntu/rs/JiT/VisionModels/Scenic_OWLv2/big_vision")
27
+ return ap.parse_args()
28
+
29
+
30
+ # ----------------- Utils -----------------
31
+ def _has_jpgs(path: str) -> bool:
32
+ exts = ("*.jpg", "*.jpeg", "*.JPG", "*.JPEG")
33
+ for pat in exts:
34
+ if glob.glob(os.path.join(path, pat)):
35
+ return True
36
+ return False
37
+
38
+
39
+ def iter_image_dirs(input_dir: str, startswith: str) -> List[str]:
40
+ """
41
+ Returns a list of directories to process.
42
+ - If input_dir contains subfolders: return subfolders that contain JPGs and match startswith.
43
+ - Else if input_dir itself contains JPGs and its basename matches startswith: return [input_dir].
44
+ """
45
+ input_dir = os.path.abspath(input_dir)
46
+ subs = sorted([p for p in glob.glob(os.path.join(input_dir, "*")) if os.path.isdir(p)])
47
+ # Prefer subfolders if any exist and contain jpgs
48
+ dirs = [d for d in subs if os.path.basename(d).startswith(startswith) and _has_jpgs(d)]
49
+ if dirs:
50
+ return dirs
51
+
52
+ # Fallback: treat input_dir itself as one set if it has jpgs
53
+ base_ok = os.path.basename(os.path.normpath(input_dir)).startswith(startswith)
54
+ if base_ok and _has_jpgs(input_dir):
55
+ return [input_dir]
56
+ return []
57
+
58
+
59
+ def ensure_dir(p: str):
60
+ os.makedirs(p, exist_ok=True)
61
+
62
+
63
+ def draw_single_box(frame_bgr: np.ndarray, box: List[float], color=(0, 255, 0), thickness=2) -> np.ndarray:
64
+ x1, y1, x2, y2 = map(int, box)
65
+ out = frame_bgr.copy()
66
+ cv2.rectangle(out, (x1, y1), (x2, y2), color, thickness)
67
+ return out
68
+
69
+
70
+ def list_images_sorted(folder: str) -> List[str]:
71
+ pats = ["*.jpg", "*.jpeg", "*.JPG", "*.JPEG"]
72
+ files = []
73
+ for pat in pats:
74
+ files.extend(glob.glob(os.path.join(folder, pat)))
75
+ # Sort by file name (lexicographic order)
76
+ return sorted(files)
77
+
78
+
79
+ # ----------------- Worker logic (imports JAX/Scenic inside) -----------------
80
+ def worker_run(args, dir_paths: List[str]):
81
+ import sys as _sys
82
+ if args.scenic_root not in _sys.path:
83
+ _sys.path.append(args.scenic_root)
84
+
85
+ # Free TF GPU to JAX in this process (why: avoid TF reserving VRAM)
86
+ import tensorflow as tf
87
+ tf.config.experimental.set_visible_devices([], "GPU")
88
+
89
+ from scenic.projects.owl_vit import configs
90
+ from scenic.projects.owl_vit import models
91
+ import jax
92
+ import functools
93
+ import owlv2_helper as helper # must be available in PYTHONPATH
94
+
95
+ class OWLv2Objectness:
96
+ def __init__(self, top_k: int = 5):
97
+ self.top_k = top_k
98
+ self.config = configs.owl_v2_clip_b16.get_config(init_mode="canonical_checkpoint")
99
+ self.module = models.TextZeroShotDetectionModule(
100
+ body_configs=self.config.model.body,
101
+ objectness_head_configs=self.config.model.objectness_head,
102
+ normalize=self.config.model.normalize,
103
+ box_bias=self.config.model.box_bias,
104
+ )
105
+ self.variables = self.module.load_variables(self.config.init_from.checkpoint_path)
106
+
107
+ self.image_embedder = jax.jit(
108
+ functools.partial(self.module.apply, self.variables, train=False, method=self.module.image_embedder)
109
+ )
110
+ self.objectness_predictor = jax.jit(
111
+ functools.partial(self.module.apply, self.variables, method=self.module.objectness_predictor)
112
+ )
113
+ self.box_predictor = jax.jit(
114
+ functools.partial(self.module.apply, self.variables, method=self.module.box_predictor)
115
+ )
116
+
117
+ def detect(self, image_bgr: np.ndarray) -> List[Tuple[List[float], float]]:
118
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
119
+ processed = helper.preprocess_images([image_rgb], self.config.dataset_configs.input_size)[0]
120
+ feature_map = self.image_embedder(processed[None, ...])
121
+ b, h, w, d = feature_map.shape
122
+ image_features = feature_map.reshape(b, h * w, d)
123
+
124
+ obj_logits = self.objectness_predictor(image_features)["objectness_logits"]
125
+ raw_boxes = self.box_predictor(image_features=image_features, feature_map=feature_map)["pred_boxes"]
126
+
127
+ obj = np.array(obj_logits[0], dtype=np.float32)
128
+ raw_boxes = np.array(raw_boxes[0], dtype=np.float32)
129
+ boxes = helper.rescale_detection_box(raw_boxes, image_rgb)
130
+
131
+ if len(obj) == 0:
132
+ return []
133
+
134
+ k = min(self.top_k, len(obj))
135
+ thresh = np.partition(obj, -k)[-k]
136
+
137
+ filtered: List[Tuple[List[float], float]] = []
138
+ H, W = image_rgb.shape[:2]
139
+ for box, score in zip(boxes, obj):
140
+ if score < thresh:
141
+ continue
142
+ if helper.too_small(box) or helper.too_large(box, image_rgb):
143
+ continue
144
+ x1, y1, x2, y2 = box
145
+ x1 = max(0, min(float(x1), W - 1))
146
+ y1 = max(0, min(float(y1), H - 1))
147
+ x2 = max(0, min(float(x2), W - 1))
148
+ y2 = max(0, min(float(y2), H - 1))
149
+ filtered.append(([x1, y1, x2, y2], float(score)))
150
+
151
+ kept_boxes = helper.remove_overlapping_bboxes([b for b, _ in filtered])
152
+
153
+ def _match_score(bb: List[float]) -> float:
154
+ arr = np.array([b for b, _ in filtered], dtype=np.float32)
155
+ idx = int(np.argmin(np.abs(arr - np.array(bb, dtype=np.float32)).sum(axis=1)))
156
+ return filtered[idx][1]
157
+
158
+ return [(bb, _match_score(bb)) for bb in kept_boxes]
159
+
160
+ detector = OWLv2Objectness(top_k=args.top_k)
161
+
162
+ for dpath in tqdm(dir_paths, desc=f"Worker{args.worker_idx}", unit="set"):
163
+ stem = os.path.basename(os.path.normpath(dpath))
164
+ images = list_images_sorted(dpath)
165
+ if not images:
166
+ print(f"[WARN][w{args.worker_idx}] No JPGs under: {dpath}")
167
+ continue
168
+
169
+ saved_cnt = 0
170
+ pbar = tqdm(total=len(images), desc=f"{stem}[w{args.worker_idx}]", unit="img", leave=False)
171
+
172
+ for idx, ipath in enumerate(images):
173
+ pbar.update(1)
174
+ if args.frame_stride > 1 and (idx % args.frame_stride) != 0:
175
+ continue
176
+
177
+ frame = cv2.imread(ipath, cv2.IMREAD_COLOR)
178
+ if frame is None:
179
+ print(f"[WARN][w{args.worker_idx}] Cannot read: {ipath}")
180
+ continue
181
+
182
+ boxes_scores = detector.detect(frame)
183
+ if boxes_scores:
184
+ boxes_scores = sorted(boxes_scores, key=lambda x: x[1], reverse=True)[:args.top_k]
185
+
186
+ fname = os.path.basename(ipath)
187
+ for i, (box, score) in enumerate(boxes_scores):
188
+ out_dir = os.path.join(args.output_dir, stem, f"object_{i}")
189
+ ensure_dir(out_dir)
190
+ vis = draw_single_box(frame, box, color=(0, 255, 0), thickness=2)
191
+ cv2.imwrite(os.path.join(out_dir, fname), vis)
192
+
193
+ saved_cnt += 1
194
+ if args.max_frames and saved_cnt >= args.max_frames:
195
+ break
196
+
197
+ pbar.close()
198
+
199
+
200
+ # ----------------- Master -----------------
201
+ def main():
202
+ args = parse_args()
203
+
204
+ # Child worker path
205
+ if args.worker_idx >= 0:
206
+ if not args.shard_file or not os.path.exists(args.shard_file):
207
+ raise RuntimeError("Worker requires --shard_file with JSON list of folder paths.")
208
+ with open(args.shard_file, "r", encoding="utf-8") as f:
209
+ dir_paths = json.load(f)
210
+ worker_run(args, dir_paths)
211
+ return
212
+
213
+ # Master path
214
+ dir_paths = iter_image_dirs(args.input_dir, args.startswith)
215
+ if not dir_paths:
216
+ print(f"[INFO] No JPG folders (or JPGs) startwith '{args.startswith}' under {args.input_dir}")
217
+ return
218
+
219
+ num_workers = max(1, int(args.num_workers))
220
+ shards: List[List[str]] = [[] for _ in range(num_workers)]
221
+ for i, d in enumerate(dir_paths):
222
+ shards[i % num_workers].append(d)
223
+
224
+ procs = []
225
+ tmp_dir = os.path.join(args.output_dir, "_shards_tmp")
226
+ ensure_dir(tmp_dir)
227
+
228
+ for w in range(num_workers):
229
+ shard_path = os.path.join(tmp_dir, f"shard_{w}.json")
230
+ with open(shard_path, "w", encoding="utf-8") as f:
231
+ json.dump(shards[w], f, ensure_ascii=False, indent=0)
232
+
233
+ # Bind GPU: cycle through available GPU ids [0..num_workers-1]
234
+ env = os.environ.copy()
235
+ env["CUDA_VISIBLE_DEVICES"] = str(w) # one GPU per worker
236
+
237
+ cmd = [
238
+ sys.executable, __file__,
239
+ "--input_dir", args.input_dir,
240
+ "--startswith", args.startswith,
241
+ "--output_dir", args.output_dir,
242
+ "--frame_stride", str(args.frame_stride),
243
+ "--top_k", str(args.top_k),
244
+ "--max_frames", str(args.max_frames),
245
+ "--num_workers", str(num_workers),
246
+ "--worker_idx", str(w),
247
+ "--shard_file", shard_path,
248
+ "--scenic_root", args.scenic_root,
249
+ ]
250
+ print(f"[Master] Launch worker {w}: GPU={env['CUDA_VISIBLE_DEVICES']} folders={len(shards[w])}")
251
+ procs.append(subprocess.Popen(cmd, env=env))
252
+
253
+ # wait
254
+ rc = 0
255
+ for p in procs:
256
+ p.wait()
257
+ rc |= p.returncode
258
+
259
+ if rc != 0:
260
+ print("[Master] Some workers failed. Return code:", rc)
261
+ else:
262
+ print("[Master] All workers done. Output:", args.output_dir)
263
+
264
+
265
+ if __name__ == "__main__":
266
+ main()
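
One detail of `detect()` worth spelling out: `np.partition(obj, -k)[-k]` yields the k-th largest objectness score, and every box scoring at or above it is kept, so ties can produce slightly more than `top_k` candidates before the size and overlap filters run. A small self-contained illustration:

```python
import numpy as np

# Toy objectness scores for 6 candidate boxes.
obj = np.array([0.12, 0.85, 0.40, 0.85, 0.03, 0.77], dtype=np.float32)
k = 3

# k-th largest value; np.partition only guarantees that the element at index -k
# ends up in its sorted position, which is all we need for a threshold.
thresh = np.partition(obj, -k)[-k]   # -> 0.77
keep = obj >= thresh                 # ties at the threshold are kept
print(thresh, keep)                  # 0.77 [False  True False  True False  True]
```

Also note the master/worker split: invoked without `--worker_idx`, the script round-robins the JPG folders into `--num_workers` shards and re-launches itself once per shard with `CUDA_VISIBLE_DEVICES` pinned to a single GPU.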
big_vision/.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__
big_vision/CONTRIBUTING.md ADDED
@@ -0,0 +1,26 @@
1
+ # How to Contribute
2
+
3
+ At this time we do not plan to accept non-trivial contributions. The main
4
+ purpose of this codebase is to allow the community to reproduce results from our
5
+ publications.
6
+
7
+ You are however free to start a fork of the project for your purposes as
8
+ permitted by the license.
9
+
10
+ ## Contributor License Agreement
11
+
12
+ Contributions to this project must be accompanied by a Contributor License
13
+ Agreement (CLA). You (or your employer) retain the copyright to your
14
+ contribution; this simply gives us permission to use and redistribute your
15
+ contributions as part of the project. Head over to
16
+ <https://cla.developers.google.com/> to see your current agreements on file or
17
+ to sign a new one.
18
+
19
+ You generally only need to submit a CLA once, so if you've already submitted one
20
+ (even if it was for a different project), you probably don't need to do it
21
+ again.
22
+
23
+ ## Community Guidelines
24
+
25
+ This project follows
26
+ [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
big_vision/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
big_vision/README.md ADDED
@@ -0,0 +1,499 @@
+ # Big Vision
+
+ This codebase is designed for training large-scale vision models using
+ [Cloud TPU VMs](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)
+ or GPU machines. It is based on [Jax](https://github.com/google/jax)/[Flax](https://github.com/google/flax)
+ libraries, and uses [tf.data](https://www.tensorflow.org/guide/data) and
+ [TensorFlow Datasets](https://www.tensorflow.org/datasets) for scalable and
+ reproducible input pipelines.
+
+ The open-sourcing of this codebase has two main purposes:
+ 1. Publishing the code of research projects developed in this codebase (see a
+    list below).
+ 2. Providing a strong starting point for running large-scale vision experiments
+    on GPU machines and Google Cloud TPUs, which should scale seamlessly and
+    out of the box from a single TPU core to a distributed setup with up to 2048
+    TPU cores.
+
+ `big_vision` aims to support research projects at Google. We are unlikely to
+ work on feature requests or accept external contributions, unless they were
+ pre-approved (ask in an issue first). For a well-supported transfer-only
+ codebase, see also [vision_transformer](https://github.com/google-research/vision_transformer).
+
+ Note that `big_vision` is a quite dynamic codebase and, while we intend to keep
+ the core code fully-functional at all times, we cannot guarantee timely updates
+ of the project-specific code that lives in the `.../proj/...` subfolders.
+ However, we provide a [table](#project-specific-commits) with the last known
+ commits at which specific projects were known to work.
+
+ The following research projects were originally conducted in the `big_vision`
+ codebase:
+
+ ### Architecture research
+
+ - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929), by
+   Alexey Dosovitskiy*, Lucas Beyer*, Alexander Kolesnikov*, Dirk Weissenborn*,
+   Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
+   Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby*
+ - [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560), by
+   Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*\
+   Resources: [config](big_vision/configs/proj/scaling_laws/train_vit_g.py).
+ - [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270), by
+   Andreas Steiner*, Alexander Kolesnikov*, Xiaohua Zhai*, Ross Wightman,
+   Jakob Uszkoreit, and Lucas Beyer*
+ - [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601), by
+   Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*,
+   Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner,
+   Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy\
+   Resources: [config](big_vision/configs/mlp_mixer_i1k.py).
+ - [Better plain ViT baselines for ImageNet-1k](https://arxiv.org/abs/2205.01580), by
+   Lucas Beyer, Xiaohua Zhai, Alexander Kolesnikov\
+   Resources: [config](big_vision/configs/vit_s16_i1k.py)
+ - [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by
+   Alexander Kolesnikov^*, André Susano Pinto^*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*\
+   Resources: [readme](big_vision/configs/proj/uvim/README.md), [configs](big_vision/configs/proj/uvim), [colabs](big_vision/configs/proj/uvim).
+ - [FlexiViT: One Model for All Patch Sizes](https://arxiv.org/abs/2212.08013), by
+   Lucas Beyer*, Pavel Izmailov*, Alexander Kolesnikov*, Mathilde Caron*, Simon
+   Kornblith*, Xiaohua Zhai*, Matthias Minderer*, Michael Tschannen*, Ibrahim
+   Alabdulmohsin*, Filip Pavetic*\
+   Resources: [readme](big_vision/configs/proj/flexivit/README.md), [configs](big_vision/configs/proj/flexivit).
+ - [Dual PatchNorm](https://arxiv.org/abs/2302.01327), by Manoj Kumar, Mostafa Dehghani, Neil Houlsby.
+ - [Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design](https://arxiv.org/abs/2305.13035), by
+   Ibrahim Alabdulmohsin*, Xiaohua Zhai*, Alexander Kolesnikov, Lucas Beyer*.
+ - (partial) [Scaling Vision Transformers to 22 Billion Parameters](https://arxiv.org/abs/2302.05442), by
+   Mostafa Dehghani*, Josip Djolonga*, Basil Mustafa*, Piotr Padlewski*, Jonathan Heek*, *wow many middle authors*, Neil Houlsby*.
+ - (partial) [Finite Scalar Quantization: VQ-VAE Made Simple](https://arxiv.org/abs/2309.15505), by
+   Fabian Mentzer, David Minnen, Eirikur Agustsson, Michael Tschannen.
+ - [GIVT: Generative Infinite-Vocabulary Transformers](https://arxiv.org/abs/2312.02116), by
+   Michael Tschannen, Cian Eastwood, Fabian Mentzer.\
+   Resources: [readme](big_vision/configs/proj/givt/README.md), [config](big_vision/configs/proj/givt/givt_imagenet2012.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/givt/givt_demo_colab.ipynb).
+ - [Unified Auto-Encoding with Masked Diffusion](https://arxiv.org/abs/2406.17688), by
+   Philippe Hansen-Estruch, Sriram Vishwanath, Amy Zhang, Manan Tomar.
+
+
+ ### Multimodal research
+
+ - [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by
+   Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers,
+   Alexander Kolesnikov, and Lucas Beyer*\
+   Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).
+ - [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045), by
+   Michael Tschannen, Basil Mustafa, Neil Houlsby\
+   Resources: [readme](big_vision/configs/proj/clippo/README.md), [config](big_vision/configs/proj/clippo/train_clippo.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb).
+ - [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343), by
+   Xiaohua Zhai*, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer*\
+   Resources: [colab and models](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb), code TODO.
+ - [A Study of Autoregressive Decoders for Multi-Tasking in Computer Vision](https://arxiv.org/abs/2303.17376), by
+   Lucas Beyer*, Bo Wan*, Gagan Madan*, Filip Pavetic*, Andreas Steiner*, Alexander Kolesnikov, André Susano Pinto, Emanuele Bugliarello, Xiao Wang, Qihang Yu, Liang-Chieh Chen, Xiaohua Zhai*.
+ - [Image Captioners Are Scalable Vision Learners Too](https://arxiv.org/abs/2306.07915), by
+   Michael Tschannen*, Manoj Kumar*, Andreas Steiner*, Xiaohua Zhai, Neil Houlsby, Lucas Beyer*.\
+   Resources: [readme](big_vision/configs/proj/cappa/README.md), [config](big_vision/configs/proj/cappa/pretrain.py), [model](big_vision/models/proj/cappa/cappa.py).
+ - [Three Towers: Flexible Contrastive Learning with Pretrained Image Models](https://arxiv.org/abs/2305.16999), by Jannik Kossen, Mark Collier, Basil Mustafa, Xiao Wang, Xiaohua Zhai, Lucas Beyer, Andreas Steiner, Jesse Berent, Rodolphe Jenatton, Efi Kokiopoulou.
+ - (partial) [PaLI: A Jointly-Scaled Multilingual Language-Image Model](https://arxiv.org/abs/2209.06794), by Xi Chen, Xiao Wang, Soravit Changpinyo, *wow so many middle authors*, Anelia Angelova, Xiaohua Zhai, Neil Houlsby, Radu Soricut.
+ - (partial) [PaLI-3 Vision Language Models: Smaller, Faster, Stronger](https://arxiv.org/abs/2310.09199), by Xi Chen, Xiao Wang, Lucas Beyer, Alexander Kolesnikov, Jialin Wu, Paul Voigtlaender, Basil Mustafa, Sebastian Goodman, Ibrahim Alabdulmohsin, Piotr Padlewski, Daniel Salz, Xi Xiong, Daniel Vlasic, Filip Pavetic, Keran Rong, Tianli Yu, Daniel Keysers, Xiaohua Zhai, Radu Soricut.
+ - [LocCa](https://arxiv.org/abs/2403.19596), by
+   Bo Wan, Michael Tschannen, Yongqin Xian, Filip Pavetic, Ibrahim Alabdulmohsin, Xiao Wang, André Susano Pinto, Andreas Steiner, Lucas Beyer, Xiaohua Zhai.
+ - [PaliGemma](https://arxiv.org/abs/2407.07726),
+   [PaliGemma 2](https://arxiv.org/abs/2412.03555), by *wow many authors*.\
+   Resources: [readme](big_vision/configs/proj/paligemma/README.md),
+   [model](big_vision/models/proj/paligemma/paligemma.py),
+   [transfer configs](big_vision/configs/proj/paligemma/transfers),
+   [datasets](big_vision/datasets),
+   [CountBenchQA](big_vision/datasets/countbenchqa/data/countbench_paired_questions.json).
+
+ ### Training
+
+ - [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by
+   Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil,
+   and Alexander Kolesnikov*\
+   Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing).
+ - [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412), by
+   Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur
+ - [Surrogate Gap Minimization Improves Sharpness-Aware Training](https://arxiv.org/abs/2203.08065), by Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan and Ting Liu \
+   Resources: [trainer](big_vision/trainers/proj/gsam/gsam.py), [config](big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py), [reproduced results](https://github.com/google-research/big_vision/pull/8#pullrequestreview-1078557411)
+ - [Tuning computer vision models with task rewards](https://arxiv.org/abs/2302.08242), by
+   André Susano Pinto*, Alexander Kolesnikov*, Yuge Shi, Lucas Beyer, Xiaohua Zhai.
+ - (partial) [VeLO: Training Versatile Learned Optimizers by Scaling Up](https://arxiv.org/abs/2211.09760) by
+   Luke Metz, James Harrison, C. Daniel Freeman, Amil Merchant, Lucas Beyer, James Bradbury, Naman Agrawal, Ben Poole, Igor Mordatch, Adam Roberts, Jascha Sohl-Dickstein.
+
+ ### Misc
+
+ - [Are we done with ImageNet?](https://arxiv.org/abs/2006.07159), by
+   Lucas Beyer*, Olivier J. Hénaff*, Alexander Kolesnikov*, Xiaohua Zhai*, Aäron van den Oord*.
+ - [No Filter: Cultural and Socioeconomic Diversity in Contrastive Vision-Language Models](https://arxiv.org/abs/2405.13777), by
+   Angéline Pouget, Lucas Beyer, Emanuele Bugliarello, Xiao Wang, Andreas Peter Steiner, Xiaohua Zhai, Ibrahim Alabdulmohsin.
+
+ # Codebase high-level organization and principles in a nutshell
+
+ The main entry point is a trainer module, which typically does all the
+ boilerplate related to creating a model and an optimizer, loading the data,
+ checkpointing and training/evaluating the model inside a loop. We provide the
+ canonical trainer `train.py` in the root folder. Normally, individual projects
+ within `big_vision` fork and customize this trainer.
+
+ All models, evaluators and preprocessing operations live in the corresponding
+ subdirectories and can often be reused between different projects. We encourage
+ compatible APIs within these directories to facilitate reusability, but it is
+ not strictly enforced, as individual projects may need to introduce their custom
+ APIs.
+
+ We have a powerful configuration system, with the configs living in the
+ `configs/` directory. Custom trainers and modules can directly extend/modify
+ the configuration options, as illustrated by the sketch below.
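+
+ For illustration, here is a minimal config sketch in the style used throughout
+ `configs/` (the values are made up for this example and do not correspond to
+ an existing config file):
+
+ ```
+ import ml_collections as mlc
+
+ def get_config():
+   """A tiny, hypothetical example config."""
+   config = mlc.ConfigDict()
+   config.total_epochs = 90
+   config.input = dict(batch_size=4096)
+   config.model_name = 'vit'
+   config.model = dict(variant='S/16')
+   return config
+ ```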
+
+ Project-specific code resides in the `.../proj/...` namespace. It is not always
+ possible to keep the project-specific code in sync with the core `big_vision`
+ libraries. Below we provide the [last known commit](#project-specific-commits)
+ for each project at which the project code is expected to work.
+
+ Training jobs are robust to interruptions and will resume seamlessly from the
+ last saved checkpoint (assuming a user provides the correct `--workdir` path).
+
+ Each configuration file contains a comment at the top with a `COMMAND` snippet
+ to run it, and some hint of expected runtime and results. See below for more
+ details, but generally speaking, running on a GPU machine involves calling
+ `python -m COMMAND`, while running on TPUs, including multi-host setups,
+ involves
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
+     --command "bash big_vision/run_tpu.sh COMMAND"
+ ```
+
+ See instructions below for more details on how to run `big_vision` code on a
+ GPU machine or Google Cloud TPU.
+
+ By default we write checkpoints and logfiles. The logfiles are a list of JSON
+ objects, and we provide a short and straightforward [example colab to read
+ and display the logs and checkpoints](https://colab.research.google.com/drive/1R_lvV542WUp8Q2y8sbyooZOGCplkn7KI?usp=sharing).
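+
+ As a minimal sketch (assuming the metrics file name `big_vision_metrics.txt`
+ and a metric key such as `"val/prec@1"`, both of which may differ per trainer
+ and config), the logs can also be read directly:
+
+ ```
+ import json
+
+ # Each non-empty line of the logfile is a standalone JSON object.
+ with open('workdir/big_vision_metrics.txt') as f:
+   records = [json.loads(line) for line in f if line.strip()]
+ print([r['val/prec@1'] for r in records if 'val/prec@1' in r])
+ ```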
+
+ # Current and future contents
+
+ The first release contains the core part of pre-training, transferring, and
+ evaluating classification models at scale on Cloud TPU VMs.
+
+ We have since added the following key features and projects:
+ - Contrastive Image-Text model training and evaluation as in LiT and CLIP.
+ - Patient and consistent distillation.
+ - Scaling ViT.
+ - MLP-Mixer.
+ - UViM.
+
+ Features and projects we plan to release in the near future, in no particular
+ order:
+ - ImageNet-21k in TFDS.
+ - Loading misc public models used in our publications (NFNet, MoCov3, DINO).
+ - Memory-efficient Polyak-averaging implementation.
+ - Advanced JAX compute and memory profiling. We are using internal tools for
+   this, but may eventually add support for the publicly available ones.
+
+ We will continue releasing code of our future publications developed within
+ `big_vision` here.
+
+ ### Non-content
+
+ The following exist in the internal variant of this codebase, and there is no
+ plan for their release:
+ - Regular regression tests for both quality and speed. They rely heavily on
+   internal infrastructure.
+ - Advanced logging, monitoring, and plotting of experiments. This also relies
+   heavily on internal infrastructure. However, we are open to ideas on this
+   and may add some in the future, especially if implemented in a
+   self-contained manner.
+ - Not yet published, ongoing research projects.
+
+
+ # GPU Setup
+
+ We first discuss how to set up and run `big_vision` on a (local) GPU machine,
+ and then discuss the setup for Cloud TPUs. Note that the data preparation step
+ for the (local) GPU setup can be largely reused for the Cloud TPU setup. While
+ the instructions skip this for brevity, we highly recommend using a
+ [virtual environment](https://docs.python.org/3/library/venv.html) when
+ installing python dependencies.
+
+ ## Setting up python packages
+
+ The first step is to check out `big_vision` and install the relevant python
+ dependencies:
+
+ ```
+ git clone https://github.com/google-research/big_vision
+ cd big_vision/
+ pip3 install --upgrade pip
+ pip3 install -r big_vision/requirements.txt
+ ```
+
+ The latest version of the `jax` library can be fetched as
+
+ ```
+ pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ ```
+
+ You may need a different `jax` package, depending on the CUDA and cuDNN
+ libraries installed on your machine. Please consult the
+ [official jax documentation](https://github.com/google/jax#pip-installation-gpu-cuda)
+ for more information.
+
+ ## Preparing tfds data
+
+ For unified and reproducible access to standard datasets we opted to use the
+ `tensorflow_datasets` (`tfds`) library. It requires each dataset to be
+ downloaded, preprocessed and then stored on a hard drive (or, if you use
+ "Google Cloud", preferably stored in a "GCP bucket").
+
+ Many datasets can be downloaded and preprocessed automatically when used
+ for the first time. Nevertheless, we intentionally disable this feature and
+ recommend doing the dataset preparation step separately, ahead of the first
+ run. This makes debugging easier if problems arise, and some datasets, like
+ `imagenet2012`, require manually downloaded data.
+
+ Most of the datasets, e.g. `cifar100`, `oxford_iiit_pet` or `imagenet_v2`,
+ can be fully automatically downloaded and prepared by running
+
+ ```
+ cd big_vision/
+ python3 -m big_vision.tools.download_tfds_datasets cifar100 oxford_iiit_pet imagenet_v2
+ ```
+
+ A full list of datasets is available at [this link](https://www.tensorflow.org/datasets/catalog/overview#all_datasets).
+
+ Some datasets, like `imagenet2012` or `imagenet2012_real`, require the data to
+ be downloaded manually and placed into `$TFDS_DATA_DIR/downloads/manual/`,
+ which defaults to `~/tensorflow_datasets/downloads/manual/`. For example, for
+ `imagenet2012` and `imagenet2012_real` one needs to place the official
+ `ILSVRC2012_img_train.tar` and `ILSVRC2012_img_val.tar` files in that directory
+ and then run
+ `python3 -m big_vision.tools.download_tfds_datasets imagenet2012 imagenet2012_real`
+ (which may take ~1 hour).
+
+ If you use `Google Cloud`, and TPUs in particular, you can then upload
+ the preprocessed data (stored in `$TFDS_DATA_DIR`) to a
+ "Google Cloud Bucket" and use the bucket on any of your (TPU) virtual
+ machines to access the data.
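+
+ As a quick sanity check, a short sketch (adjust the dataset name and
+ `data_dir` to whatever you actually prepared) confirms that `tfds` sees a
+ prepared dataset without triggering any download:
+
+ ```
+ import tensorflow_datasets as tfds
+
+ # Points tfds at the prepared data; fails fast if the dataset is missing.
+ builder = tfds.builder('cifar100', data_dir='~/tensorflow_datasets')
+ print(builder.info.splits['train'].num_examples)
+ ```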
+
+ ## Running on a GPU machine
+
+ Finally, after installing all python dependencies and preparing `tfds` data,
+ the user can run the job using the config of their choice, e.g. to train the
+ `ViT-S/16` model on ImageNet data, one should run the following command:
+
+ ```
+ python3 -m big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir workdirs/`date '+%m-%d_%H%M'`
+ ```
+
+ or to train MLP-Mixer-B/16, run (note the `gpu8` config param that reduces the default batch size and epoch count):
+
+ ```
+ python3 -m big_vision.train --config big_vision/configs/mlp_mixer_i1k.py:gpu8 --workdir workdirs/`date '+%m-%d_%H%M'`
+ ```
+
+ # Cloud TPU VM setup
+
+ ## Create TPU VMs
+
+ To create a single machine with 8 TPU cores, follow this Cloud TPU JAX
+ document:
+ https://cloud.google.com/tpu/docs/run-calculation-jax
+
+ To support large-scale vision research, more cores with multiple hosts are
+ recommended. Below we provide instructions on how to set this up.
+
+ First, create some useful variables, which will be reused:
+
+ ```
+ export NAME=<a name of the TPU deployment, e.g. my-tpu-machine>
+ export ZONE=<GCP geographical zone, e.g. europe-west4-a>
+ export GS_BUCKET_NAME=<Name of the storage bucket, e.g. my_bucket>
+ ```
+
+ The following command line will create TPU VMs with 32 cores across
+ 4 hosts.
+
+ ```
+ gcloud compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version tpu-ubuntu2204-base
+ ```
+
+ ## Install `big_vision` on TPU VMs
+
+ Fetch the `big_vision` repository, copy it to all TPU VM hosts, and install
+ dependencies.
+
+ ```
+ git clone https://github.com/google-research/big_vision
+ gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
+ ```
+
+ ## Download and prepare TFDS datasets
+
+ We recommend preparing `tfds` data locally as described above and then uploading
+ the data to a `Google Cloud` bucket. However, if you prefer, the datasets which
+ do not require manual downloads can be prepared automatically using a TPU
+ machine as described below. Note that TPU machines have only 100 GB of disk
+ space, and multihost TPU slices do not allow for external disks to be attached
+ in write mode, so the instructions below may not work for preparing large
+ datasets. As yet another alternative, we provide instructions
+ [on how to prepare `tfds` data on a CPU-only GCP machine](#preparing-tfds-data-on-a-standalone-gcp-cpu-machine).
+
+ Specifically, the seven TFDS datasets used during evaluations will be generated
+ under `~/tensorflow_datasets` on the TPU machine with this command:
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "TFDS_DATA_DIR=~/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
+ ```
+
+ You can then copy the datasets to the GS bucket, to make them accessible to all TPU workers.
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME"
+ ```
+
+ If you want to integrate other public or custom datasets, e.g. imagenet2012,
+ please follow [the official guideline](https://www.tensorflow.org/datasets/catalog/overview).
+
+ ## Pre-trained models
+
+ For the full list of pre-trained models check out the `load` function defined in
+ the same module as the model code. For an example config on how to use these
+ models, see `configs/transfer.py`.
+
+ ## Run the transfer script on TPU VMs
+
+ The following command line fine-tunes a pre-trained `vit-i21k-augreg-b/32` model
+ on the `cifar10` dataset.
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
+ ```
+
+ ## Run the train script on TPU VMs
+
+ To train your own big_vision models on a large dataset,
+ e.g. `imagenet2012` ([prepare the TFDS dataset](https://www.tensorflow.org/datasets/catalog/imagenet2012)),
+ run the following command line.
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
+ ```
+
+ ## FSDP training.
+
+ `big_vision` supports flexible parameter and model sharding strategies.
+ Currently, we support a popular FSDP sharding via a simple config change, see [this config example](big_vision/configs/transfer.py).
+ For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possibly adjusting the batch size depending on your hardware):
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-l/16,dataset=oxford_iiit_pet,crop=resmall_crop,fsdp=True,batch_size=256 --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
+ ```
+
+ ## Image-text training with SigLIP.
+
+ A minimal example that uses public `coco` captions data:
+
+ ```
+ gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.siglip --config big_vision/configs/proj/image_text/siglip_lit_coco.py --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`"
+ ```
+
+
+ ## Sometimes useful gcloud commands
+
+ - Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE`
+ - Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
+
+ ## Preparing `tfds` data on a standalone GCP CPU machine.
+
+ First create a new machine and a disk (feel free to adjust the exact machine type and disk settings/capacity):
+
+ ```
+ export NAME_CPU_HOST=<A name of a CPU-only machine>
+ export NAME_DISK=<A name of a disk>
+ gcloud compute instances create $NAME_CPU_HOST --machine-type c3-standard-22 --zone $ZONE --image-family ubuntu-2204-lts --image-project ubuntu-os-cloud
+ gcloud compute disks create $NAME_DISK --size 1000GB --zone $ZONE --type pd-balanced
+ ```
+
+ Now attach the disk to the newly created machine:
+
+ ```
+ gcloud compute instances attach-disk $NAME_CPU_HOST --disk $NAME_DISK --zone $ZONE
+ ```
+
+ Next, `ssh` into the machine with `gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE` and
+ [follow the instructions to format and mount the disk](https://cloud.google.com/compute/docs/disks/format-mount-disk-linux).
+ Let's assume it was mounted to `/mnt/disks/tfds`.
+
+ Almost there, now clone and set up `big_vision`:
+
+ ```
+ gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "git clone https://github.com/google-research/big_vision.git && cd big_vision && sh big_vision/run_tpu.sh"
+ ```
+
+ Finally, prepare the dataset (e.g. `coco_captions`) using the utility script and
+ copy the result to your Google Cloud bucket:
+
+ ```
+ gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "cd big_vision && TFDS_DATA_DIR=/mnt/disks/tfds/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets coco_captions"
+ gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "rm -rf /mnt/disks/tfds/tensorflow_datasets/downloads && gsutil cp -r /mnt/disks/tfds/tensorflow_datasets gs://$GS_BUCKET_NAME"
+ ```
+
+
+ # ViT baseline
+
+ We provide a well-tuned ViT-S/16 baseline in the config file named
+ `vit_s16_i1k.py`. It achieves 76.5% accuracy on the ImageNet validation split
+ in 90 epochs of training, making it a strong and simple starting point for
+ research on the ViT models.
+
+ Please see our [arXiv note](https://arxiv.org/abs/2205.01580) for more details,
+ and if this baseline happens to be useful for your research, consider citing
+
+ ```
+ @article{vit_baseline,
+   url = {https://arxiv.org/abs/2205.01580},
+   author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
+   title = {Better plain ViT baselines for ImageNet-1k},
+   journal = {arXiv preprint arXiv:2205.01580},
+   year = {2022},
+ }
+ ```
+
+ # Project specific commits
+
+ The table below lists the last known commit at which each project's code is
+ expected to work. The core code and configs are expected to work at head.
+
+ | Project    | Commit                                                                                          |
+ |------------|-------------------------------------------------------------------------------------------------|
+ | UViM       | https://github.com/google-research/big_vision/commit/21bd6ebe253f070f584d8b777ad76f4abce51bef   |
+ | image_text | https://github.com/google-research/big_vision/commit/8921d5141504390a8a4f7b2dacb3b3c042237290   |
+ | distill    | https://github.com/google-research/big_vision/commit/2f3f493af048dbfd97555ff6060f31a0e686f17f   |
+ | GSAM       | WIP                                                                                             |
+ | CLIPPO     | https://github.com/google-research/big_vision/commit/fd2d3bd2efc9d89ea959f16cd2f58ae8a495cd44   |
+ | CapPa      | https://github.com/google-research/big_vision/commit/7ace659452dee4b68547575352c022a2eef587a5   |
+ | GIVT       | https://github.com/google-research/big_vision/commit/0cb70881dd33b3343b769347dc19793c4994b8cb   |
+
+ # Citing the codebase
+
+ If you found this codebase useful for your research, please consider using
+ the following BibTeX to cite it:
+
+ ```
+ @misc{big_vision,
+   author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
+   title = {Big Vision},
+   year = {2022},
+   publisher = {GitHub},
+   journal = {GitHub repository},
+   howpublished = {\url{https://github.com/google-research/big_vision}}
+ }
+ ```
+
+ # Disclaimer
+
+ This is not an official Google Product.
+
+ # License
+
+ Unless explicitly noted otherwise, everything in the big_vision codebase
+ (including models and colabs) is released under the Apache2 license.
+ See the LICENSE file for the full license text.
big_vision/__init__.py ADDED
File without changes
big_vision/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (179 Bytes).
 
big_vision/__pycache__/utils.cpython-310.pyc ADDED
Binary file (52.6 kB).
 
big_vision/configs/__init__.py ADDED
File without changes
big_vision/configs/bit_i1k.py ADDED
@@ -0,0 +1,102 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=line-too-long
+ r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370
+
+ Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128:
+
+   big_vision.train \
+       --config big_vision/configs/bit_i1k.py \
+       --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+       --config.model.depth 50 --config.model.width 1
+ """
+
+ # from big_vision.configs.common_fewshot import get_fewshot_lsr
+ import ml_collections as mlc
+
+
+ def get_config(runlocal=False):
+   """Config for training on ImageNet-1k."""
+   config = mlc.ConfigDict()
+
+   config.seed = 0
+   config.total_epochs = 90
+   config.num_classes = 1000
+   config.loss = 'softmax_xent'
+
+   config.input = dict()
+   config.input.data = dict(
+       name='imagenet2012',
+       split='train[:99%]',
+   )
+   config.input.batch_size = 4096
+   config.input.cache_raw = True  # Needs up to 120GB of RAM!
+   config.input.shuffle_buffer_size = 250_000  # Per host.
+
+   pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
+   pp_common += '|value_range(-1, 1)|keep("image", "labels")'
+   config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
+   pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
+
+   config.log_training_steps = 50
+   config.ckpt_steps = 1000
+
+   # Model section
+   config.model_name = 'bit'
+   config.model = dict(
+       depth=50,  # You can also pass e.g. [3, 5, 10, 2]
+       width=1.0,
+   )
+
+   # Optimizer section
+   config.optax_name = 'big_vision.momentum_hp'
+   config.grad_clip_norm = 1.0
+
+   # linear scaling rule. Don't forget to sweep if sweeping batch_size.
+   config.wd = (1e-4 / 256) * config.input.batch_size
+   config.lr = (0.1 / 256) * config.input.batch_size
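+   # E.g. with the default batch_size=4096 this gives lr = 0.1 * 16 = 1.6 and
+   # wd = 1e-4 * 16 = 1.6e-3.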
+   config.schedule = dict(decay_type='cosine', warmup_steps=1000)
+
+   # Eval section
+   def get_eval(split, dataset='imagenet2012'):
+     return dict(
+         type='classification',
+         data=dict(name=dataset, split=split),
+         pp_fn=pp_eval.format(lbl='label'),
+         loss_name=config.loss,
+         log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+         cache='final_data',
+     )
+   config.evals = {}
+   config.evals.train = get_eval('train[:2%]')
+   config.evals.minival = get_eval('train[99%:]')
+   config.evals.val = get_eval('validation')
+   config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+   config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+   config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+   # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
+   # config.evals.fewshot.log_steps = 1000
+
+   if runlocal:
+     config.input.batch_size = 32
+     config.input.cache_raw = False
+     config.input.shuffle_buffer_size = 100
+
+     local_eval = config.evals.val
+     config.evals = {'val': local_eval}
+     config.evals.val.cache = 'none'
+
+   return config
big_vision/configs/bit_i21k.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=line-too-long
+ r"""A config for pre-training BiT on ImageNet-21k.
+
+ This config relies on the Imagenet-21k tfds dataset, which is not yet
+ available publicly in TFDS. We intend to add the dataset to public TFDS soon,
+ and this config will then be runnable.
+ """
+
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
+ import ml_collections as mlc
+
+
+ def get_config():
+   """Config for training on imagenet-21k."""
+   config = mlc.ConfigDict()
+
+   config.seed = 0
+   config.total_epochs = 90
+   config.num_classes = 21843
+   config.init_head_bias = -10.0
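+   # sigmoid(-10) is roughly 4.5e-5, so every one of the 21843 classes starts
+   # near probability zero, a sensible prior for multi-label sigmoid training.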
+   config.loss = 'sigmoid_xent'
+
+   config.input = dict()
+   config.input.data = dict(
+       name='imagenet21k',
+       split='full[51200:]',
+   )
+   config.input.batch_size = 4096
+   config.input.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
+
+   pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
+   pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
+   pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
+   config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
+   pp_eval = 'decode|resize_small(256)|central_crop(224)'
+
+   config.log_training_steps = 50
+   config.ckpt_steps = 1000
+
+   # Model section
+   config.model_name = 'bit_paper'
+   config.model = dict(depth=50, width=1.0)
+
+   # Optimizer section
+   config.optax_name = 'big_vision.momentum_hp'
+   config.grad_clip_norm = 1.0
+
+   # linear scaling rule. Don't forget to sweep if sweeping batch_size.
+   config.lr = (0.03 / 256) * config.input.batch_size
+   config.wd = (3e-5 / 256) * config.input.batch_size
+   config.schedule = dict(decay_type='cosine', warmup_steps=5000)
+
+   # Evaluations on i21k itself.
+   def eval_i21k(split):
+     return dict(
+         type='classification',
+         data={**config.input.data, 'split': split},
+         pp_fn=pp_eval + pp_common_i21k,
+         loss_name=config.loss,
+         log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+     )
+   config.evals = {}
+   config.evals.test = eval_i21k('full[:25_600]')
+   config.evals.val = eval_i21k('full[25_600:51_200]')
+   config.evals.train = eval_i21k('full[51_200:76_800]')
+
+   # Few-shot evaluators
+   config.evals.fewshot = get_fewshot_lsr()
+   config.evals.fewshot.log_steps = 25_000
+
+   return config
big_vision/configs/common.py ADDED
@@ -0,0 +1,188 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """A few things commonly used across A LOT of config files."""
+
+ import string
+
+ import ml_collections as mlc
+
+
+ def input_for_quicktest(config_input, quicktest):
+   if quicktest:
+     config_input.batch_size = 8
+     config_input.shuffle_buffer_size = 10
+     config_input.cache_raw = False
+
+
+ def parse_arg(arg, lazy=False, **spec):
+   """Makes ConfigDict's get_config single-string argument more usable.
+
+   Example use in the config file:
+
+     import big_vision.configs.common as bvcc
+     def get_config(arg):
+       arg = bvcc.parse_arg(arg,
+           res=(224, int),
+           runlocal=False,
+           schedule='short',
+       )
+
+       # ...
+
+       config.shuffle_buffer = 250_000 if not arg.runlocal else 50
+
+   Ways that values can be passed when launching:
+
+     --config amazing.py:runlocal,schedule=long,res=128
+     --config amazing.py:res=128
+     --config amazing.py:runlocal  # A boolean needs no value for "true".
+     --config amazing.py:runlocal=False  # Explicit false boolean.
+     --config amazing.py:128  # The first spec entry may be passed unnamed alone.
+
+   Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
+   'false', '' to False).
+
+   Args:
+     arg: the string argument that's passed to get_config.
+     lazy: allow lazy parsing of arguments, which are not in spec. For these,
+       the type is inferred automatically from the value (bool/int/float/str).
+     **spec: the name and default values of the expected options.
+       If the value is a tuple, the value's first element is the default value,
+       and the second element is a function called to convert the string.
+       Otherwise the type is automatically extracted from the default value.
+
+   Returns:
+     ConfigDict object with extracted type-converted values.
+   """
+   # Normalize arg and spec layout.
+   arg = arg or ''  # Normalize None to empty string
+   spec = {k: get_type_with_default(v) for k, v in spec.items()}
+
+   result = mlc.ConfigDict(type_safe=False)  # For convenient dot-access only.
+
+   # Expand convenience-cases for a single parameter without = sign.
+   if arg and ',' not in arg and '=' not in arg:
+     # (think :runlocal) If it's the name of sth in the spec (or there is no
+     # spec), it's that in bool.
+     if arg in spec or not spec:
+       arg = f'{arg}=True'
+     # Otherwise, it is the value for the first entry in the spec.
+     else:
+       arg = f'{list(spec.keys())[0]}={arg}'
+       # Yes, we rely on Py3.7 insertion order!
+
+   # Now, expand the `arg` string into a dict of keys and values:
+   raw_kv = {raw_arg.split('=')[0]:
+                 raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
+             for raw_arg in arg.split(',') if raw_arg}
+
+   # And go through the spec, using provided or default value for each:
+   for name, (default, type_fn) in spec.items():
+     val = raw_kv.pop(name, None)
+     result[name] = type_fn(val) if val is not None else default
+
+   if raw_kv:
+     if lazy:  # Process args which are not in spec.
+       for k, v in raw_kv.items():
+         result[k] = autotype(v)
+     else:
+       raise ValueError(f'Unhandled config args remain: {raw_kv}')
+
+   return result
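+ # Illustrative example: parse_arg('runlocal,res=128', res=(224, int),
+ # runlocal=False) returns a ConfigDict with res=128 (an int) and
+ # runlocal=True.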
+
+
+ def get_type_with_default(v):
+   """Returns (v, string_to_v_type) with strict bool parsing."""
+   # For bool, do safe string conversion.
+   if isinstance(v, bool):
+     def strict_bool(x):
+       assert x.lower() in {'true', 'false', ''}
+       return x.lower() == 'true'
+     return (v, strict_bool)
+   # If already a (default, type) tuple, use that.
+   if isinstance(v, (tuple, list)):
+     assert len(v) == 2 and isinstance(v[1], type), (
+         'List or tuple types are currently not supported because we use `,` as'
+         ' dumb delimiter. Contributions (probably using ast) welcome. You can'
+         ' unblock by using a string with eval(s.replace(";", ",")) or similar')
+     return (v[0], v[1])
+   # Otherwise, derive the type from the default value.
+   return (v, type(v))
+
+
+ def autotype(x):
+   """Auto-converts string to bool/int/float if possible."""
+   assert isinstance(x, str)
+   if x.lower() in {'true', 'false'}:
+     return x.lower() == 'true'  # Returns as bool.
+   try:
+     return int(x)  # Returns as int.
+   except ValueError:
+     try:
+       return float(x)  # Returns as float.
+     except ValueError:
+       return x  # Returns as str.
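+ # Illustrative behaviour: autotype('true') -> True, autotype('8') -> 8,
+ # autotype('1e-3') -> 0.001, autotype('B/16') -> 'B/16'.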
+
+
+ def pack_arg(**kw):
+   """Packs key-word args as a string to be parsed by `parse_arg()`."""
+   for v in kw.values():
+     assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
+   return ','.join([f'{k}={v}' for k, v in kw.items()])
+
+
+ def arg(**kw):
+   """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
+   return {'config_arg': pack_arg(**kw), **kw}
+
+
+ def _get_field_ref(config_dict, field_name):
+   path = field_name.split('.')
+   for field in path[:-1]:
+     config_dict = getattr(config_dict, field)
+   return config_dict.get_ref(path[-1])
+
+
+ def format_str(format_string, config):
+   """Format string with reference fields from config.
+
+   This makes it easy to build preprocess strings that contain references to
+   fields that are edited later. E.g.:
+
+   ```
+   config = mlc.ConfigDict()
+   config.res = (256, 256)
+   config.pp = bvcc.format_str('resize({res})', config)
+   ...
+   # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
+   config.res = (512, 512)
+   assert config.pp == 'resize((512, 512))'
+   ```
+
+   Args:
+     format_string: string to format with references.
+     config: ConfigDict to get references to format the string.
+
+   Returns:
+     A reference field which renders a string using references to config fields.
+   """
+   output = ''
+   parts = string.Formatter().parse(format_string)
+   for (literal_text, field_name, format_spec, conversion) in parts:
+     assert not format_spec and not conversion
+     output += literal_text
+     if field_name:
+       output += _get_field_ref(config, field_name).to_str()
+   return output
big_vision/configs/common_fewshot.py ADDED
@@ -0,0 +1,60 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Most common few-shot eval configuration."""
+
+ import ml_collections as mlc
+
+
+ def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
+                     runlocal=False, pp=None, **kw):
+   """Returns a standard-ish fewshot eval configuration."""
+   kw.setdefault('representation_layer', 'pre_logits')
+   kw.setdefault('shots', (1, 5, 10, 25))
+   kw.setdefault('l2_reg', 2.0 ** 10)
+   kw.setdefault('num_seeds', 3)
+   kw.setdefault('prefix', '')  # No prefix as we already use a/ z/ and zz/
+
+   # Backward-compatible default:
+   if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']):  # pylint: disable=line-too-long
+     kw['log_steps'] = 25_000
+
+   config = mlc.ConfigDict(kw)
+   config.type = 'fewshot_lsr'
+   config.datasets = {
+       'caltech': ('caltech101', 'train', 'test'),  # copybara:strip
+       'cars': ('cars196:2.1.0', 'train', 'test'),
+       'cifar100': ('cifar100', 'train', 'test'),
+       'dtd': ('dtd', 'train', 'test'),
+       # The first 65000 ImageNet samples have at least 30 shots per any class.
+       # Commented out by default because it needs manual download.
+       # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
+       'pets': ('oxford_iiit_pet', 'train', 'test'),
+       'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
+   } if not runlocal else {
+       'pets': ('oxford_iiit_pet', 'train', 'test'),
+   }
+
+   pp = pp or '|'.join([
+       'decode',
+       f'resize({resize_resolution})',
+       f'central_crop({target_resolution})',
+       'value_range(-1,1)'
+   ])
+   pp += '|keep("image", "label")'
+   config.pp_train = pp
+   config.pp_eval = pp
+   config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
+
+   return config
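+
+ # Illustrative use (mirroring bit_i21k.py):
+ #   config.evals.fewshot = get_fewshot_lsr()
+ #   config.evals.fewshot.log_steps = 25_000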
big_vision/configs/load_and_eval.py ADDED
@@ -0,0 +1,143 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pytype: disable=not-writable,attribute-error
+ # pylint: disable=line-too-long,missing-function-docstring
+ r"""A config to load and eval key models using the core train.py.
+
+ The runtime varies widely depending on the model, but each one should reproduce
+ the corresponding paper's numbers.
+ This configuration makes use of the "arg" to get_config to select which model
+ to run, so a few examples are given below:
+
+ Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k:
+
+   big_vision.train \
+       --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \
+       --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50
+
+ Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper:
+
+   big_vision.train \
+       --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \
+       --config.model.variant B/32 --config.model_init howto-i21k-B/32
+ """
+
+ import big_vision.configs.common as bvcc
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
+
+
+ def eval_only(config, batch_size, spec_for_init):
+   """Set a few configs that turn trainer into (almost) eval-only."""
+   config.total_steps = 0
+   config.input = {}
+   config.input.batch_size = batch_size
+   config.input.data = dict(name='bv:dummy', spec=spec_for_init)
+   config.optax_name = 'identity'
+   config.lr = 0.0
+
+   config.mesh = [('data', -1)]
+   config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
+   config.sharding_rules = [('act_batch', ('data',))]
+
+   return config
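+ # Note: total_steps=0 together with the 'identity' optimizer means the train
+ # loop performs no parameter updates, so a run effectively just executes the
+ # evaluators configured below.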
+
+
+ def get_config(arg=''):
+   config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)
+
+   # Make the config eval-only by setting some dummies.
+   eval_only(config, config.batch_size, spec_for_init=dict(
+       image=dict(shape=(224, 224, 3), dtype='float32'),
+   ))
+
+   config.evals = dict(fewshot=get_fewshot_lsr())
+
+   # Just calls the function with the name given as `config`.
+   # Could also be a giant if-block if you're into that kind of thing.
+   globals()[config.name](config)
+   return config
+
+
+ def bit_paper(config):
+   config.num_classes = 1000
+
+   config.model_name = 'bit_paper'
+   config.model_init = 'M-imagenet2012'  # M = i21k, -imagenet2012 = fine-tuned
+   config.model = dict(width=1, depth=50)
+
+   def get_eval(split, lbl, dataset='imagenet2012_real'):
+     return dict(
+         type='classification',
+         data=dict(name=dataset, split=split),
+         loss_name='softmax_xent',
+         cache='none',  # Only run once, on low-mem machine.
+         pp_fn=(
+             'decode|resize(384)|value_range(-1, 1)'
+             f'|onehot(1000, key="{lbl}", key_result="labels")'
+             '|keep("image", "labels")'
+         ),
+     )
+   config.evals.test = get_eval('validation', 'original_label')
+   config.evals.real = get_eval('validation', 'real_label')
+   config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
+
+
+ def vit_i1k(config):
+   config.num_classes = 1000
+
+   config.model_name = 'vit'
+   config.model_init = ''  # Will be set in sweep.
+   config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                       rep_size=True)
+
+   config.evals.val = dict(
+       type='classification',
+       data=dict(name='imagenet2012', split='validation'),
+       pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+       loss_name='softmax_xent',
+       cache='none',  # Only run once, on low-mem machine.
+   )
+
+
+ def mlp_mixer_i1k(config):
+   config.num_classes = 1000
+
+   config.model_name = 'mlp_mixer'
+   config.model_init = ''  # Will be set in sweep.
+   config.model = dict(variant='L/16')
+
+   config.evals.val = dict(
+       type='classification',
+       data=dict(name='imagenet2012', split='validation'),
+       pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+       loss_name='softmax_xent',
+       cache='none',  # Only run once, on low-mem machine.
+   )
+
+
+ def vit_i21k(config):
+   config.num_classes = 21843
+
+   config.model_name = 'vit'
+   config.model_init = ''  # Will be set in sweep.
+   config.model = dict(variant='B/32', pool_type='tok')
+
+   config.evals.val = dict(
+       type='classification',
+       data=dict(name='imagenet21k', split='full[:51200]'),
+       pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")',
+       loss_name='sigmoid_xent',
+       cache='none',  # Only run once, on low-mem machine.
+   )
big_vision/configs/mlp_mixer_i1k.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=line-too-long
+ r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k").
+
+ Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128
+ with 300 epochs. A shorter 60-epoch run is expected to get to 70.5% in 27m.
+
+   big_vision.train \
+       --config big_vision/configs/mlp_mixer_i1k.py \
+       --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
+ """
+
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
+ import ml_collections as mlc
+
+
+ def get_config(mode=None):
+   """Config for training Mixer on i1k."""
+   config = mlc.ConfigDict()
+
+   config.seed = 0
+   config.total_epochs = 300
+   config.num_classes = 1000
+   config.loss = 'sigmoid_xent'
+   config.init_head_bias = -6.9
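+   # sigmoid(-6.9) is roughly 1/1000, so initial class probabilities start
+   # close to the uniform prior over the 1000 classes.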
+
+   config.input = dict()
+   config.input.data = dict(
+       name='imagenet2012',
+       split='train[:99%]',
+   )
+   config.input.batch_size = 4096
+   config.input.cache_raw = True  # Needs up to 120GB of RAM!
+   config.input.shuffle_buffer_size = 250_000
+
+   config.input.pp = (
+       'decode_jpeg_and_inception_crop(224)'
+       '|flip_lr'
+       '|randaug(2,15)'
+       '|value_range(-1, 1)'
+       '|onehot(1000, key="label", key_result="labels")'
+       '|keep("image", "labels")'
+   )
+   pp_eval = (
+       'decode'
+       '|resize_small(256)|central_crop(224)'
+       '|value_range(-1, 1)'
+       '|onehot(1000, key="{lbl}", key_result="labels")'
+       '|keep("image", "labels")'
+   )
+
+   # To continue using the near-defunct randaug op.
+   config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+   config.log_training_steps = 50
+   config.ckpt_steps = 1000
+
+   config.prefetch_to_device = 2
+
+   # Model section
+   config.model_name = 'mlp_mixer'
+   config.model = dict()
+   config.model.variant = 'B/16'
+   config.model.stoch_depth = 0.1
+
+   config.mixup = dict(fold_in=None, p=0.5)
+
+   # Optimizer section
+   config.optax_name = 'scale_by_adam'
+   config.grad_clip_norm = 1.
+
+   config.lr = 0.001
+   config.wd = 1e-4
+   config.schedule = dict(
+       decay_type='linear',
+       warmup_steps=10_000,
+       linear_end=1e-5,
+   )
+
+   # Eval section
+   def get_eval(split, dataset='imagenet2012'):
+     return dict(
+         type='classification',
+         data=dict(name=dataset, split=split),
+         pp_fn=pp_eval.format(lbl='label'),
+         loss_name=config.loss,
+         log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+         cache_final=mode != 'gpu8',
+     )
+   config.evals = {}
+   config.evals.train = get_eval('train[:2%]')
+   config.evals.minival = get_eval('train[99%:]')
+   config.evals.val = get_eval('validation')
+   config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+   config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+   config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+   config.fewshot = get_fewshot_lsr()
+
+   if mode == 'gpu8':
+     config.total_epochs = 60
+     config.input.batch_size = 512
+     config.input.cache_raw = False
+   if mode == 'regression_test':
+     config.total_epochs = 60
+
+   return config
big_vision/configs/transfer.py ADDED
@@ -0,0 +1,186 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=line-too-long,missing-function-docstring
+ r"""A config for transferring vit-augreg.
+
+ Best HP selected on (mini)val, expected test results (repeated 5 times):
+
+ ViT-Augreg-B/32:
+   Dataset, crop, learning rate, mean (%), range (%)
+   - ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33]
+   - Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6]
+   - Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62]
+   - Pets, inception_crop, 0.003, 93.78, [93.62...94.00]
+   - Flowers, inception_crop, 0.003, 99.43, [99.42...99.45]
+
+
+ Command to run:
+   big_vision.train \
+       --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \
+       --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03
+ """
+
+ import big_vision.configs.common as bvcc
+ import ml_collections as mlc
+
+
+ def _set_model(config, model):
+   """Load pre-trained models: vit or bit."""
+   # Reset the head to init (of zeros) when transferring.
+   config.model_load = dict(dont_load=['head/kernel', 'head/bias'])
+
+   if model == 'vit-i21k-augreg-b/32':
+     # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
+     config.model_name = 'vit'
+     config.model_init = 'howto-i21k-B/32'
+     config.model = dict(variant='B/32', pool_type='tok')
+   elif model == 'vit-i21k-augreg-l/16':
+     config.model_name = 'vit'
+     config.model_init = 'howto-i21k-L/16'
+     config.model = dict(variant='L/16', pool_type='tok')
+   elif model == 'vit-s16':
+     config.model_name = 'vit'
+     config.model_init = 'i1k-s16-300ep'
+     config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                         rep_size=True)
+   elif model == 'bit-m-r50x1':
+     config.model_name = 'bit_paper'
+     config.model_init = 'M'
+     config.model = dict(depth=50, width=1)
+   else:
+     raise ValueError(f'Unknown model: {model}, please define customized model.')
64
+
65
+
66
+ def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
67
+ if dataset == 'cifar10':
68
+ _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
69
+ elif dataset == 'cifar100':
70
+ _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
71
+ elif dataset == 'imagenet2012':
72
+ _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
73
+ _set_imagenet_variants(config)
74
+ elif dataset == 'oxford_iiit_pet':
75
+ _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
76
+ elif dataset == 'oxford_flowers102':
77
+ _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
78
+ else:
79
+ raise ValueError(
80
+ f'Unknown dataset: {dataset}, please define customized dataset.')
81
+
82
+
83
+ def _set_task(config, dataset, train, val, test, n_cls,
84
+ steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
85
+ flip=True, h_res=448, l_res=384):
86
+ """Vision task with val and test splits."""
87
+ config.total_steps = steps
88
+ config.schedule = dict(
89
+ warmup_steps=warmup,
90
+ decay_type='cosine',
91
+ )
92
+
93
+ config.input.data = dict(name=dataset, split=train)
94
+ pp_common = (
95
+ '|value_range(-1, 1)|'
96
+ f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
97
+ 'keep("image", "labels")'
98
+ )
99
+
100
+ if crop == 'inception_crop':
101
+ pp_train = f'decode|inception_crop({l_res})'
102
+ elif crop == 'resmall_crop':
103
+ pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
104
+ elif crop == 'resize_crop':
105
+ pp_train = f'decode|resize({h_res})|random_crop({l_res})'
106
+ else:
107
+ raise ValueError(f'Unknown crop: {crop}. Must be one of: '
108
+ 'inception_crop, resmall_crop, resize_crop')
109
+ if flip:
110
+ pp_train += '|flip_lr'
111
+ config.input.pp = pp_train + pp_common
112
+
113
+ pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
114
+ config.num_classes = n_cls
115
+
116
+ def get_eval(split):
117
+ return dict(
118
+ type='classification',
119
+ data=dict(name=dataset, split=split),
120
+ loss_name='softmax_xent',
121
+ log_steps=100,
122
+ pp_fn=pp,
123
+ )
124
+ config.evals = dict(val=get_eval(val), test=get_eval(test))
125
+
126
+
127
+ def _set_imagenet_variants(config, h_res=448, l_res=384):
128
+ """Evaluation tasks on ImageNet variants: v2 and real."""
129
+ pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
130
+ '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
131
+ 'keep("image", "labels")'
132
+ )
133
+
134
+ # Special-case rename for i1k (val+test -> minival+val)
135
+ config.evals.minival = config.evals.val
136
+ config.evals.val = config.evals.test
137
+ # NOTE: keep test == val for convenience in subsequent analysis.
138
+
139
+ config.evals.real = dict(type='classification')
140
+ config.evals.real.data = dict(name='imagenet2012_real', split='validation')
141
+ config.evals.real.pp_fn = pp.format(lbl='real_label')
142
+ config.evals.real.loss_name = config.loss
143
+ config.evals.real.log_steps = 100
144
+
145
+ config.evals.v2 = dict(type='classification')
146
+ config.evals.v2.data = dict(name='imagenet_v2', split='test')
147
+ config.evals.v2.pp_fn = pp.format(lbl='label')
148
+ config.evals.v2.loss_name = config.loss
149
+ config.evals.v2.log_steps = 100
150
+
151
+
152
+ def get_config(arg=None):
153
+ """Config for adaptation."""
154
+ arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
155
+ h_res=448, l_res=384, batch_size=512, fsdp=False,
156
+ runlocal=False)
157
+ config = mlc.ConfigDict()
158
+
159
+ config.input = {}
160
+ config.input.batch_size = arg.batch_size if not arg.runlocal else 8
161
+ config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100
162
+
163
+ config.log_training_steps = 10
164
+ config.ckpt_steps = 1000
165
+ config.ckpt_timeout = 600
166
+
167
+ # Optimizer section
168
+ config.optax_name = 'big_vision.momentum_hp'
169
+ config.grad_clip_norm = 1.0
170
+ config.wd = None # That's our default, but just being explicit here!
171
+ config.loss = 'softmax_xent'
172
+ config.lr = 0.01
173
+ config.mixup = dict(p=0.0)
174
+
175
+ config.seed = 0
176
+
177
+ _set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)
178
+
179
+ _set_model(config, arg.model)
180
+ if arg.fsdp:
181
+ config.mesh = [('data', -1)]
182
+ config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
183
+ config.sharding_rules = [('act_batch', ('data',))]
184
+ config.model.scan = True
185
+
186
+ return config
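
Since all of the per-dataset and per-model branching above happens inside `get_config`, the resulting config can be inspected outside the trainer by passing the same comma-separated argument string that the command line supplies after the colon. A rough sketch (assuming `big_vision` and `ml_collections` are importable):

```python
# Sketch: instantiating the transfer config directly to inspect it.
from big_vision.configs import transfer

config = transfer.get_config(
    'model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop')

print(config.model_init)   # howto-i21k-B/32
print(config.total_steps)  # 10_000 for cifar10
print(config.input.pp)     # decode|resize_small(448)|random_crop(384)|flip_lr|...
```

The per-dataset learning rates from the table in the docstring are not set here; `config.lr` stays at the generic 0.01 default and is meant to be overridden with `--config.lr=...` as in the example command.
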
big_vision/configs/vit_i1k.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270
17
+
18
+ This config does NOT include regularization (dropout, stochastic depth), which
19
+ was shown to help with B/32, B/16, L/16 models in the paper (Figure 4).
20
+
21
+ This configuration uses the "arg" parameter of get_config to select which model
22
+ to run, so a few examples are given below:
23
+
24
+ Run training of a B/16 model:
25
+
26
+ big_vision.train \
27
+ --config big_vision/configs/vit_i1k.py:variant=B/16 \
28
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
29
+
30
+ Run training of a B/32 model with custom aug-strength and 300ep:
31
+
32
+ big_vision.train \
33
+ --config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \
34
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
35
+ --config.total_epochs 300
36
+ """
37
+
38
+ import big_vision.configs.common as bvcc
39
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
40
+ import ml_collections as mlc
41
+
42
+ MIXUP_DEF = {
43
+ 'none': dict(p=0.0, fold_in=None),
44
+ 'light1': dict(p=0.0, fold_in=None),
45
+ 'light2': dict(p=0.2, fold_in=None),
46
+ 'medium1': dict(p=0.2, fold_in=None),
47
+ 'medium2': dict(p=0.5, fold_in=None),
48
+ 'strong1': dict(p=0.5, fold_in=None),
49
+ 'strong2': dict(p=0.8, fold_in=None),
50
+ }
51
+
52
+ RANDAUG_DEF = {
53
+ 'none': '',
54
+ 'light1': 'randaug(2,0)', # Actually not nothing!
55
+ 'light2': 'randaug(2,10)',
56
+ 'medium1': 'randaug(2,15)',
57
+ 'medium2': 'randaug(2,15)',
58
+ 'strong1': 'randaug(2,20)',
59
+ 'strong2': 'randaug(2,20)',
60
+ }
61
+
62
+
63
+ def get_config(arg=None):
64
+ """Config for training."""
65
+ arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='')
66
+ config = mlc.ConfigDict()
67
+
68
+ config.seed = 0
69
+ config.total_epochs = 300
70
+ config.num_classes = 1000
71
+ config.loss = 'sigmoid_xent'
72
+ config.init_head_bias = -6.9
73
+
74
+ # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
75
+ # Note: this is a good average between 30ep and 300ep; sometimes you could
76
+ # find a slightly better setting for either of them.
77
+ aug_setting = arg.aug or {
78
+ 'Ti/16': 'light1',
79
+ 'S/32': 'medium1',
80
+ 'S/16': 'medium2',
81
+ 'B/32': 'medium2',
82
+ 'B/16': 'medium2',
83
+ 'L/16': 'medium2',
84
+ }[arg.variant]
85
+
86
+ config.input = dict()
87
+ config.input.data = dict(
88
+ name='imagenet2012',
89
+ split='train[:99%]',
90
+ )
91
+ config.input.batch_size = 4096
92
+ config.input.cache = 'raw_data' if arg.runlocal else 'none' # Needs up to 120GB of RAM!
93
+ config.input.shuffle_buffer_size = 250_000
94
+
95
+ pp_common = (
96
+ '|value_range(-1, 1)'
97
+ '|onehot(1000, key="{lbl}", key_result="labels")'
98
+ '|keep("image", "labels")'
99
+ )
100
+ config.input.pp = (
101
+ 'decode_jpeg_and_inception_crop(224)|flip_lr|' +
102
+ RANDAUG_DEF[aug_setting] +
103
+ pp_common.format(lbl='label')
104
+ )
105
+ pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
106
+
107
+ # To continue using the near-defunct randaug op.
108
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
109
+
110
+ # Aggressive pre-fetching because our models here are small, so we not only
111
+ # can afford it, but we also need it for the smallest models to not be
112
+ # bottle-necked by the input pipeline. Play around with it for -L models tho.
113
+ config.input.prefetch = 8
114
+ config.prefetch_to_device = 4
115
+
116
+ config.log_training_steps = 50
117
+ config.ckpt_steps = 1000
118
+
119
+ # Model section
120
+ config.model_name = 'vit'
121
+ config.model = dict(
122
+ variant=arg.variant,
123
+ rep_size=True,
124
+ pool_type='tok',
125
+ )
126
+
127
+ # Optimizer section
128
+ config.grad_clip_norm = 1.0
129
+ config.optax_name = 'scale_by_adam'
130
+ config.optax = dict(mu_dtype='bfloat16')
131
+ # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
132
+ # almost always behaves exactly like adam, but at a fraction of the memory
133
+ # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
134
+ # good idea to try it when you are memory-bound!
135
+ # config.optax_name = 'big_vision.scale_by_adafactor'
136
+ # A good flag to play with when hitting instabilities is the following:
137
+ # config.optax = dict(beta2_cap=0.95)
138
+
139
+ config.lr = 0.001
140
+ config.wd = 0.0001
141
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
142
+
143
+ config.mixup = MIXUP_DEF[aug_setting]
144
+
145
+ # Eval section
146
+ def get_eval(split, dataset='imagenet2012'):
147
+ return dict(
148
+ type='classification',
149
+ data=dict(name=dataset, split=split),
150
+ pp_fn=pp_eval.format(lbl='label'),
151
+ loss_name=config.loss,
152
+ log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
153
+ cache='final_data' if arg.runlocal else 'none',
154
+ )
155
+ config.evals = {}
156
+ config.evals.train = get_eval('train[:2%]')
157
+ config.evals.minival = get_eval('train[99%:]')
158
+ config.evals.val = get_eval('validation')
159
+ config.evals.v2 = get_eval('test', dataset='imagenet_v2')
160
+ config.evals.real = get_eval('validation', dataset='imagenet2012_real')
161
+ config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
162
+
163
+ config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
164
+ config.fewshot.log_steps = 10_000
165
+
166
+ # Make a few things much smaller for quick local debugging testruns.
167
+ if arg.runlocal:
168
+ config.input.shuffle_buffer_size = 10
169
+ config.input.batch_size = 8
170
+ config.input.cache_raw = False
171
+ config.evals.train.data.split = 'train[:16]'
172
+ config.evals.minival.data.split = 'train[:16]'
173
+ config.evals.val.data.split = 'validation[:16]'
174
+ config.evals.v2.data.split = 'test[:16]'
175
+ config.evals.real.data.split = 'validation[:16]'
176
+
177
+ return config
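
Because the empty-string default for `aug` falls back to the per-variant table, the effective mixup/RandAugment settings for a given variant can be checked directly; for example (a sketch, same import assumptions as above):

```python
# Sketch: the 'medium2' default that a B/16 run picks up.
from big_vision.configs import vit_i1k

config = vit_i1k.get_config('variant=B/16')
print(config.mixup)     # p=0.5, fold_in=None   (MIXUP_DEF['medium2'])
print(config.input.pp)  # ...|flip_lr|randaug(2,15)|value_range(-1, 1)|...
```
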
big_vision/configs/vit_i21k.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT on ImageNet-21k as in https://arxiv.org/abs/2106.10270
17
+
18
+ This config relies on the Imagenet-21k tfds dataset, which is not yet
19
+ available publicly in TFDS. We intend to add the dataset to public TFDS soon,
20
+ and this config will then be runnable.
21
+
22
+ Note that regularization (dropout, stochastic depth) is not currently
23
+ implemented. This was not beneficial for ImageNet-21k pre-training.
24
+ """
25
+
26
+ import big_vision.configs.common as bvcc
27
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
28
+ import ml_collections as mlc
29
+
30
+ MIXUP_DEF = {
31
+ 'none': dict(p=0.0, fold_in=None),
32
+ 'light1': dict(p=0.0, fold_in=None),
33
+ 'light2': dict(p=0.2, fold_in=None),
34
+ 'medium1': dict(p=0.2, fold_in=None),
35
+ 'medium2': dict(p=0.5, fold_in=None),
36
+ 'strong1': dict(p=0.5, fold_in=None),
37
+ 'strong2': dict(p=0.8, fold_in=None),
38
+ }
39
+
40
+ RANDAUG_DEF = {
41
+ 'none': '',
42
+ 'light1': 'randaug(2,0)', # Actually not nothing!
43
+ 'light2': 'randaug(2,10)',
44
+ 'medium1': 'randaug(2,15)',
45
+ 'medium2': 'randaug(2,15)',
46
+ 'strong1': 'randaug(2,20)',
47
+ 'strong2': 'randaug(2,20)',
48
+ }
49
+
50
+
51
+ def get_config(arg=None):
52
+ """Config for training."""
53
+ arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None)
54
+ config = mlc.ConfigDict()
55
+
56
+ config.seed = 0
57
+ config.total_epochs = 300
58
+ config.num_classes = 21843
59
+ config.init_head_bias = -10.0
60
+ config.loss = 'sigmoid_xent'
61
+
62
+ # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
63
+ # Note, this here is a good average between 30ep and 300ep, sometimes you coud
64
+ # find a slightly better setting for either of them.
65
+ aug_setting = {
66
+ 'Ti/16': 'none',
67
+ 'S/32': 'none',
68
+ 'S/16': 'light1',
69
+ 'B/32': 'light2',
70
+ 'B/16': 'light2',
71
+ 'L/16': 'medium2',
72
+ }[arg.variant]
73
+
74
+ config.input = dict()
75
+ config.input.data = dict(
76
+ name='imagenet21k',
77
+ split='full[51200:]',
78
+ )
79
+ config.input.batch_size = 4096
80
+ config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
81
+
82
+ pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
83
+ pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
84
+ pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
85
+ config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k
86
+ pp_eval = 'decode|resize_small(256)|central_crop(224)'
87
+
88
+ # To continue using the near-defunct randaug op.
89
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
90
+
91
+ # Aggressive pre-fetching because our models here are small, so we not only
92
+ # can afford it, but we also need it for the smallest models to not be
93
+ # bottle-necked by the input pipeline. Play around with it for -L models tho.
94
+ config.input.prefetch = 8
95
+ config.prefetch_to_device = 4
96
+
97
+ config.log_training_steps = 50
98
+ config.ckpt_steps = 1000
99
+
100
+ # Model section
101
+ config.model_name = 'vit'
102
+ config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn')
103
+
104
+ # Optimizer section
105
+ config.optax_name = 'scale_by_adam'
106
+ config.optax = dict(mu_dtype='bfloat16')
107
+ config.grad_clip_norm = 1.0
108
+
109
+ config.lr = 0.001
110
+ config.wd = 0.0001
111
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
112
+
113
+ config.mixup = MIXUP_DEF[aug_setting]
114
+
115
+ # Evaluations on i21k itself.
116
+ def eval_i21k(split):
117
+ return dict(
118
+ type='classification',
119
+ data={**config.input.data, 'split': split},
120
+ pp_fn=pp_eval + pp_common_i21k,
121
+ loss_name=config.loss,
122
+ log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
123
+ )
124
+ config.evals = {}
125
+ config.evals.test = eval_i21k('full[:25_600]')
126
+ config.evals.val = eval_i21k('full[25_600:51_200]')
127
+ config.evals.train = eval_i21k('full[51_200:76_800]')
128
+
129
+ # Few-shot evaluators
130
+ config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
131
+ config.evals.fewshot.log_steps = 25_000
132
+
133
+ # Make a few things much smaller for quick local debugging testruns.
134
+ if arg.runlocal:
135
+ config.input.shuffle_buffer_size = 10
136
+ config.input.batch_size = 8
137
+ config.evals.test.data.split = 'full[:16]'
138
+ config.evals.train.data.split = 'full[:16]'
139
+ config.evals.val.data.split = 'full[:16]'
140
+ config.evals.i1k_val.data.split = 'validation[:16]'
141
+ config.evals.i1k_v2.data.split = 'test[:16]'
142
+ config.evals.i1k_a.data.split = 'test[:16]'
143
+ config.evals.i1k_r.data.split = 'test[:16]'
144
+
145
+ return config
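
One detail shared by this config and the i1k ones above is the large negative `init_head_bias` used with `sigmoid_xent`; presumably it is chosen so that the freshly initialised head predicts roughly the uniform class prior for every class. A quick numerical check of that reading:

```python
# Sketch: checking the assumption that init_head_bias ~ logit(1 / num_classes).
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

print(sigmoid(-6.9), 1 / 1000)    # ~1.01e-3 vs 1.00e-3  (ImageNet-1k configs)
print(sigmoid(-10.0), 1 / 21843)  # ~4.54e-5 vs 4.58e-5  (this ImageNet-21k config)
```
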
big_vision/configs/vit_s16_i1k.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580.
17
+
18
+ This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%,
19
+ see the tech report for more details.
20
+
21
+ Command to run:
22
+
23
+ big_vision.train \
24
+ --config big_vision/configs/vit_s16_i1k.py \
25
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
26
+
27
+ To run for 300ep, add `--config.total_epochs 300` to the command.
28
+ """
29
+
30
+ import ml_collections as mlc
31
+
32
+
33
+ def get_config():
34
+ """Config for training."""
35
+ config = mlc.ConfigDict()
36
+
37
+ config.seed = 0
38
+ config.total_epochs = 90
39
+ config.num_classes = 1000
40
+ config.loss = 'softmax_xent'
41
+
42
+ config.input = {}
43
+ config.input.data = dict(
44
+ name='imagenet2012',
45
+ split='train[:99%]',
46
+ )
47
+ config.input.batch_size = 1024
48
+ config.input.cache_raw = True # Needs up to 120GB of RAM!
49
+ config.input.shuffle_buffer_size = 250_000
50
+
51
+ pp_common = (
52
+ '|value_range(-1, 1)'
53
+ '|onehot(1000, key="{lbl}", key_result="labels")'
54
+ '|keep("image", "labels")'
55
+ )
56
+ config.input.pp = (
57
+ 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' +
58
+ pp_common.format(lbl='label')
59
+ )
60
+ pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
61
+
62
+ # To continue using the near-defunct randaug op.
63
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
64
+
65
+ config.log_training_steps = 50
66
+ config.ckpt_steps = 1000
67
+
68
+ # Model section
69
+ config.model_name = 'vit'
70
+ config.model = dict(
71
+ variant='S/16',
72
+ rep_size=True,
73
+ pool_type='gap',
74
+ posemb='sincos2d',
75
+ )
76
+
77
+ # Optimizer section
78
+ config.grad_clip_norm = 1.0
79
+ config.optax_name = 'scale_by_adam'
80
+ config.optax = dict(mu_dtype='bfloat16')
81
+
82
+ config.lr = 0.001
83
+ config.wd = 0.0001
84
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
85
+
86
+ config.mixup = dict(p=0.2, fold_in=None)
87
+
88
+ # Eval section
89
+ def get_eval(split, dataset='imagenet2012'):
90
+ return dict(
91
+ type='classification',
92
+ data=dict(name=dataset, split=split),
93
+ pp_fn=pp_eval.format(lbl='label'),
94
+ loss_name=config.loss,
95
+ log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
96
+ )
97
+ config.evals = {}
98
+ config.evals.train = get_eval('train[:2%]')
99
+ config.evals.minival = get_eval('train[99%:]')
100
+ config.evals.val = get_eval('validation')
101
+ config.evals.v2 = get_eval('test', dataset='imagenet_v2')
102
+ config.evals.real = get_eval('validation', dataset='imagenet2012_real')
103
+ config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
104
+
105
+ return config
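
For a sense of scale, the epoch-based settings above translate into step counts roughly as follows (assuming the standard 1,281,167 ImageNet-2012 training images, of which `train[:99%]` keeps about 1.27M):

```python
# Sketch: approximate step counts implied by the S/16 config above.
train_images = int(1_281_167 * 0.99)     # 'train[:99%]'
steps_per_epoch = train_images // 1024   # config.input.batch_size = 1024
total_steps = 90 * steps_per_epoch       # config.total_epochs = 90
print(steps_per_epoch, total_steps)      # ~1238 steps/epoch, ~111k steps in total
print(10_000 / total_steps)              # warmup_steps=10_000 is roughly 9% of the run
```
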
big_vision/datasets/ai2d/ai2d.py ADDED
File without changes
big_vision/datasets/aokvqa/aokvqa.py ADDED
File without changes
big_vision/datasets/chartqa/chartqa.py ADDED
File without changes
big_vision/datasets/coco35l/coco35l.py ADDED
File without changes
big_vision/datasets/core.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core data functions, dispatch calls to the requested dataset."""
16
+ import importlib
17
+
18
+
19
+ # Note: intentionally not using ABC to avoid forcing implementation of every
20
+ # method, since one can imagine train-only datasets for example.
21
+ class DataSource:
22
+ """The API that any data source should implement."""
23
+
24
+ def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
25
+ """Creates this data object as a tf.data.Dataset.
26
+
27
+ This will be called separately in each process, and it is up to the dataset
28
+ implementation to shard it accordingly if desired!
29
+
30
+ Args:
31
+ ordered: if True, the dataset should use deterministic ordering, if False
32
+ it may have undefined ordering. Think of True == val, False == train.
33
+ process_split: if False then every process receives the entire dataset
34
+ (e.g. for evaluators running in a single process).
35
+ allow_cache: whether to allow caching the opened data or not.
36
+
37
+ Returns:
38
+ A tf.data.Dataset object.
39
+
40
+ Raises:
41
+ RuntimeError: if not implemented by the dataset, but called.
42
+ """
43
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
44
+
45
+ @property
46
+ def total_examples(self):
47
+ """Returns number of examples in the dataset, regardless of sharding."""
48
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
49
+
50
+ def num_examples_per_process(self):
51
+ """Returns a list of the numer of examples for each process.
52
+
53
+ This is only needed for datasets that should go through make_for_inference.
54
+
55
+ Returns:
56
+ Returns a list of the number of examples for each process.
57
+
58
+ Ideally, this would always be `[total() / nprocess] * nprocess`, but in
59
+ reality we can almost never perfectly shard a dataset across arbitrary
60
+ number of processes.
61
+
62
+ One alternative option that can work in some cases is to not even shard
63
+ the dataset and thus return `[num_examples()] * nprocess`.
64
+
65
+ Raises:
66
+ RuntimeError: if not implemented by the dataset, but called.
67
+ """
68
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
69
+
70
+
71
+ def get(name, **kw):
72
+ if name.startswith("bv:"):
73
+ mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
74
+ return mod.DataSource(**kw)
75
+ else:
76
+ mod = importlib.import_module("big_vision.datasets.tfds")
77
+ return mod.DataSource(name, **kw)
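
In other words, `get("bv:jsonl", ...)` resolves to `big_vision.datasets.jsonl.DataSource`, while any other name goes through the TFDS wrapper. A minimal sketch of what a custom module has to provide to be reachable through the `bv:` route (the class below is hypothetical and simply mirrors the interface documented above):

```python
# Sketch: a hypothetical in-memory DataSource satisfying the API above.
import jax
import numpy as np
import tensorflow as tf
import big_vision.datasets.core as ds_core


class DataSource(ds_core.DataSource):
  """Toy source over a fixed list of {'image': ..., 'label': ...} examples."""

  def __init__(self, examples):
    self.examples = list(examples)

  def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
    del allow_cache  # Nothing to cache for an in-memory list.
    idx = np.arange(len(self.examples))
    if process_split:  # Each process only sees its own shard.
      idx = np.array_split(idx, jax.process_count())[jax.process_index()]
    if not ordered:
      np.random.shuffle(idx)
    return tf.data.Dataset.from_generator(
        lambda: (self.examples[i] for i in idx),
        output_signature={"image": tf.TensorSpec([None, None, 3], tf.uint8),
                          "label": tf.TensorSpec([], tf.int64)})

  @property
  def total_examples(self):
    return len(self.examples)

  def num_examples_per_process(self):
    return [len(s) for s in np.array_split(
        np.arange(len(self.examples)), jax.process_count())]
```
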
big_vision/datasets/countbenchqa/countbenchqa.py ADDED
File without changes
big_vision/datasets/docvqa/docvqa.py ADDED
File without changes
big_vision/datasets/gqa/gqa.py ADDED
File without changes
big_vision/datasets/imagenet/class_names.py ADDED
File without changes
big_vision/datasets/infovqa/infovqa.py ADDED
File without changes
big_vision/datasets/jsonl.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simple data input from .jsonl files."""
16
+
17
+ import hashlib
18
+ import json
19
+ from multiprocessing.pool import ThreadPool
20
+ import os
21
+ import tempfile
22
+ import urllib.request
23
+
24
+ from absl import logging
25
+ import big_vision.datasets.core as ds_core
26
+ import jax
27
+ import numpy as np
28
+ import overrides
29
+ import tensorflow as tf
30
+
31
+
32
+ def cached_download(url, dest=None, verbose=True):
33
+ """Download `url` to local file and return path to that, but with caching."""
34
+ # NOTE: there is a small chance of saving corrupted data if the process is
35
+ # interrupted in the middle of writing the file. Then, reading in the input
36
+ # pipeline will fail, and the fix is to nuke the temp folder.
37
+
38
+ # Compute a temp name based on the URL, so we can check if we already
39
+ # downloaded it before.
40
+ dest = dest or os.path.join(tempfile.gettempdir(), "bv")
41
+ os.makedirs(dest, exist_ok=True)
42
+ dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())
43
+
44
+ # NOTE: we should use last-modified header to know whether to re-download.
45
+ if os.path.isfile(dest):
46
+ return dest
47
+
48
+ if verbose:
49
+ print(f"\rRetrieving {url} into {dest}", end="", flush=True)
50
+
51
+ with urllib.request.urlopen(url) as f:
52
+ data = f.read()
53
+ with open(dest, "wb+") as f:
54
+ f.write(data)
55
+ return dest
56
+
57
+
58
+ class DataSource(ds_core.DataSource):
59
+ """.jsonl DataSource."""
60
+
61
+ def __init__(self, fname, *, fopen_keys=(), download_keys=(),
62
+ start=0, stop=float("inf")):
63
+ """Create data-source that's jsonl + data files (eg images).
64
+
65
+ This correctly supports multi-host in that each host only reads a subset of
66
+ the dataset automatically. However, currently, all hosts download all items
67
+ if `download_keys` is specified. TODO: b/lbeyer - This can be improved.
68
+
69
+ Args:
70
+ fname: str, the path to the jsonl file that holds the dataset.
71
+ fopen_keys: collection of str or dict, the keys in the dataset whose
72
+ string value actually is a file-path that should be opened and read,
73
+ and its content is what goes into the batch (eg image filenames
74
+ commonly ["image"]).
75
+ If a dict, the values are folders prefixed to the filenames.
76
+ Supports gs:// for reading from buckets.
77
+ download_keys: collection of str, the keys in the dataset whose string
78
+ value actually is a URL from which the file should be downloaded first.
79
+ Files are downloaded to a persistent tmp folder using the URL hash as
80
+ filename. If the file already exists, the download is skipped.
81
+ Must be a subset of `fopen_keys`.
82
+ start: int, index of the first row to use; use for slicing the data.
83
+ stop: int or inf, index of the row after the last one to use.
84
+
85
+ Note:
86
+ This simple data input does not allow for nested/hierarchical values,
87
+ or otherwise more complicated values such as vectors. Use TFDS for that.
88
+
89
+ The start/stop arguments are used as in list slicing: [start:stop].
90
+ """
91
+ self.examples = []
92
+
93
+ with tf.io.gfile.GFile(fname) as f:
94
+ for i, line in enumerate(f):
95
+ if (start or 0) <= i < (stop or float("inf")):
96
+ try:
97
+ self.examples.append(json.loads(line))
98
+ except json.decoder.JSONDecodeError as e:
99
+ raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e
100
+
101
+ if download_keys:
102
+ for k in download_keys:
103
+ assert k in fopen_keys, (
104
+ f"{k} in download_keys but missing from fopen_keys {fopen_keys}")
105
+
106
+ # TODO: b/lbeyer - use info from trainer instead, move that to utils.
107
+ logging.info( # pylint: disable=logging-fstring-interpolation
108
+ f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
109
+ f"for dataset {fname} ({len(self.examples)} examples) ...")
110
+
111
+ def _dl_one(ex):
112
+ for k in download_keys:
113
+ ex[k] = cached_download(ex[k])
114
+
115
+ ThreadPool(100).map(_dl_one, self.examples)
116
+ print("Done")
117
+ logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")
118
+
119
+ # Normalize.
120
+ if isinstance(fopen_keys, (list, tuple)):
121
+ self.fopen_keys = {k: "" for k in fopen_keys}
122
+ else:
123
+ self.fopen_keys = fopen_keys or {}
124
+
125
+ # We need to apply fopen path prefix here already, because doing so while
126
+ # actually reading the files in TF, things are symbolic :(
127
+ for ex in self.examples:
128
+ for k, dirname in self.fopen_keys.items():
129
+ ex[k] = os.path.join(dirname, ex[k])
130
+
131
+ def _indices(self, *, process_split=True, process_index=None):
132
+ indices = np.arange(len(self.examples))
133
+
134
+ if not process_split:
135
+ return list(indices)
136
+
137
+ pid = jax.process_index() if process_index is None else process_index
138
+ return list(np.array_split(indices, jax.process_count())[pid])
139
+
140
+ @overrides.overrides
141
+ def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
142
+ del allow_cache # We don't cache anything anyways.
143
+ assert not process_split or len(self.examples) >= jax.process_count(), (
144
+ "Process splitting the data with fewer examples than processes!?")
145
+
146
+ my_idxs = self._indices(process_split=process_split)
147
+ if not ordered:
148
+ np.random.shuffle(my_idxs)
149
+
150
+ dataset = tf.data.Dataset.from_generator(
151
+ generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
152
+ output_signature={
153
+ "id": _guess_signature("0"),
154
+ **{k: _guess_signature(v) for k, v in self.examples[0].items()},
155
+ })
156
+
157
+ def _read_files(example):
158
+ for k in self.fopen_keys:
159
+ example[k] = tf.io.read_file(example[k])
160
+ return example
161
+ dataset = dataset.map(_read_files)
162
+
163
+ return dataset
164
+
165
+ @property
166
+ @overrides.overrides
167
+ def total_examples(self):
168
+ return len(self.examples)
169
+
170
+ @overrides.overrides
171
+ def num_examples_per_process(self):
172
+ return [len(self._indices(process_index=pid))
173
+ for pid in range(jax.process_count())]
174
+
175
+
176
+ def _guess_signature(value):
177
+ return tf.TensorSpec.from_tensor(tf.constant(value))
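
Putting the above together, the expected on-disk format is simply one flat JSON object per line, and any key listed in `fopen_keys` is replaced by the raw bytes of the referenced file when the dataset is read. A usage sketch (file names and paths below are made up):

```python
# Sketch: a tiny captioning dataset backed by this jsonl DataSource.
# /data/captions.jsonl (hypothetical) contains lines like:
#   {"image": "cat.jpg", "caption": "a cat on a mat"}
#   {"image": "dog.jpg", "caption": "a dog in a field"}
from big_vision.datasets.jsonl import DataSource

source = DataSource(
    "/data/captions.jsonl",
    fopen_keys={"image": "/data/images"},  # "image" values become file contents.
)
ds = source.get_tfdata(ordered=True, process_split=False)
for ex in ds.take(1):
  print(ex["id"].numpy(), ex["caption"].numpy())
  print(len(ex["image"].numpy()), "bytes of raw JPEG")
```
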
big_vision/datasets/nocaps/nocaps.py ADDED
File without changes
big_vision/datasets/okvqa/okvqa.py ADDED
File without changes
big_vision/datasets/pope/pope.py ADDED
File without changes
big_vision/datasets/refcoco/refcoco.py ADDED
File without changes
big_vision/datasets/rsvqa_hr/rsvqa_hr.py ADDED
File without changes
big_vision/datasets/rsvqa_lr/rsvqa_lr.py ADDED
File without changes
big_vision/datasets/scicap/scicap.py ADDED
File without changes
big_vision/datasets/science_qa/science_qa.py ADDED
File without changes
big_vision/datasets/screen2words/screen2words.py ADDED
File without changes
big_vision/datasets/sequence_packing.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Packed Sequence Op."""
16
+
17
+ # Forked from
18
+ # https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.
19
+
20
+
21
+ from typing import Dict, Optional, List, Union
22
+
23
+ from flax import traverse_util
24
+ import tensorflow as tf
25
+
26
+ AUTOTUNE = tf.data.experimental.AUTOTUNE
27
+ FLATTEN_SEPARATOR = "<|sep|>"
28
+
29
+
30
+ def pack_dataset(
31
+ dataset: tf.data.Dataset,
32
+ batch_size: int | None,
33
+ key2length: Union[int, Dict[str, int]],
34
+ keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset:
35
+ """Creates a 'packed' version of a dataset on-the-fly.
36
+
37
+ Wrap `tensorflow.grain` ops.
38
+
39
+ This is meant to replace the irritation of having to create a separate
40
+ "packed" version of a dataset to train efficiently on TPU.
41
+ Each example in the output dataset represents several examples in the
42
+ input dataset.
43
+
44
+ For each key in the input dataset, two additional keys are created:
45
+ <key>_segment_ids: an int32 tensor identifying the parts
46
+ representing the original example.
47
+ <key>_positions: an int32 tensor identifying the position within the original
48
+ example.
49
+
50
+ Example:
51
+ Two input examples get combined to form an output example.
52
+ The input examples are:
53
+ {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
54
+ {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
55
+ The output example is:
56
+ {
57
+ "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
58
+ "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
59
+ "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
60
+ "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
61
+ "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
62
+ "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
63
+ }
64
+ 0 represents padding in both the inputs and the outputs.
65
+ Sequences in the incoming examples are truncated to length "length", and the
66
+ sequences in the output examples all have fixed (padded) length "length".
67
+
68
+ Args:
69
+ dataset: A `tf.data.Dataset`.
70
+ batch_size: Batch size of the packed dataset.
71
+ key2length: An integer, or a dict from feature-key to integer.
72
+ keys: A list of strings (e.g. ["inputs", "targets"]).
73
+
74
+ Returns:
75
+ A `tf.data.Dataset`.
76
+ """
77
+ raise ValueError("Not implemented in OSS yet.")
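
The packing op itself is not provided in this OSS drop (the function above raises), but the transformation the docstring describes is easy to emulate; a small NumPy sketch that reproduces the worked example for one key, assuming greedy left-to-right packing:

```python
# Sketch: greedy packing of the docstring example for the "inputs" key.
import numpy as np

def pack(sequences, length):
  """Packs already-tokenized sequences into one fixed-length row."""
  tokens = np.zeros(length, np.int32)
  segment_ids = np.zeros(length, np.int32)
  positions = np.zeros(length, np.int32)
  cursor = 0
  for seg, seq in enumerate(sequences, start=1):
    seq = seq[:length - cursor]                  # truncate if it does not fit
    n = len(seq)
    tokens[cursor:cursor + n] = seq
    segment_ids[cursor:cursor + n] = seg         # which original example
    positions[cursor:cursor + n] = np.arange(n)  # position within that example
    cursor += n
  return tokens, segment_ids, positions

inputs = [[8, 7, 1], [2, 3, 4, 1]]  # the trailing 0s in the docstring are padding
for row in pack(inputs, 10):
  print(row)
# [8 7 1 2 3 4 1 0 0 0]
# [1 1 1 2 2 2 2 0 0 0]
# [0 1 2 0 1 2 3 0 0 0]
```
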
big_vision/datasets/stvqa/stvqa.py ADDED
File without changes
big_vision/datasets/tallyqa/tallyqa.py ADDED
File without changes
big_vision/datasets/textcaps/textcaps.py ADDED
File without changes