diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..955c68ac938e4ec763a4d51175c65c9989500b06 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +ckpts/clip_vit_l14_with_masks_6c17944 filter=lfs diff=lfs merge=lfs -text +ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b filter=lfs diff=lfs merge=lfs -text +ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c filter=lfs diff=lfs merge=lfs -text +images/scenic_design.jpg filter=lfs diff=lfs merge=lfs -text +images/scenic_logo.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..797fbaaf978606e1991e51e05be7813862584c5e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,32 @@ +# How to Contribute + +Scenic is a platform used for developing new methods and ideas by Google +researchers, mostly around attention-based models for computer vision or +multi-modal applications. We encourage forking the repository and continued +development. We welcome suggestions and contributions to improving Scenic. +There are a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement (CLA). You (or your employer) retain the copyright to your +contribution; this simply gives us permission to use and redistribute your +contributions as part of the project. Head over to + to see your current agreements on file or +to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code Reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). diff --git a/IOU_test.py b/IOU_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eeac3327d3bb17c7aa88d7022d4388e4c6975e0a --- /dev/null +++ b/IOU_test.py @@ -0,0 +1,21 @@ +from owlv2_helper_functions import get_iou, boxes_filter + +boxes = [ + (128.56, 4.57, 732.52, 476.05), + (569.65, 185.71, 740.31, 244.76), + (569.65, 185.71, 740.31, 244.76), + (569.65, 185.71, 740.31, 244.76), + (101.99, 99.00, 720.12, 88.63), + ] + +scores = [1.0, 0.99, 0.89, 1.0, 0.99] + +instances = ['cat', 'dog', 'dog', 'tiger', 'cat'] + + + +pred_bboxes, pred_scores, instances = boxes_filter(boxes, scores, instances) + +print(pred_bboxes) +print(pred_scores) +print(instances) \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4aa107e4a89900a7f3faa873e6b07282b1c1ff7a --- /dev/null +++ b/README.md @@ -0,0 +1,217 @@ +# Scenic +
+scenic logo +
+ +*Scenic* is a codebase with a focus on research around attention-based models +for computer vision. Scenic has been successfully used to develop +classification, segmentation, and detection models for multiple modalities +including images, video, audio, and multimodal combinations of them. + +More precisely, *Scenic* is a (i) set of shared light-weight libraries solving +tasks commonly encountered tasks when training large-scale (i.e. multi-device, +multi-host) vision models; and (ii) several *projects* containing fully +fleshed out problem-specific training and evaluation loops using these +libraries. + +Scenic is developed in [JAX](https://github.com/jax-ml/jax) and uses +[Flax](https://github.com/google/flax). + +### Contents +* [What we offer](#what-we-offer) +* [SOTA models and baselines in Scenic](#sota-models-and-baselines-in-scenic) +* [Philosophy](#philosophy) +* [Getting started](#getting-started) +* [Scenic component design](#scenic-component-design) +* [Citing Scenic](#citing-scenic) + +## What we offer +Among others *Scenic* provides + +* Boilerplate code for launching experiments, summary writing, logging, + profiling, etc; +* Optimized training and evaluation loops, losses, metrics, bi-partite matchers, + etc; +* Input-pipelines for popular vision datasets; +* [Baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models), +including strong non-attentional baselines. + + +## SOTA models and baselines in *Scenic* +There are some SOTA models and baselines in Scenic which were either developed +using Scenic, or have been reimplemented in Scenic: + +Projects that were developed in Scenic or used it for their experiments: + +* [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) +* [OmniNet: Omnidirectional Representations from Transformers](https://arxiv.org/abs/2103.01075) +* [Attention Bottlenecks for Multimodal Fusion](https://arxiv.org/abs/2107.00135) +* [TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?](https://arxiv.org/abs/2106.11297) +* [Exploring the Limits of Large Scale Pre-training](https://arxiv.org/abs/2110.02095) +* [The Efficiency Misnomer](https://arxiv.org/abs/2110.12894) +* [Discrete Representations Strengthen Vision Transformer Robustness](https://arxiv.org/abs/2111.10493) +* [Pyramid Adversarial Training Improves ViT Performance](https://arxiv.org/abs/2111.15121) +* [VUT: Versatile UI Transformer for Multi-Modal Multi-Task User Interface Modeling](https://arxiv.org/abs/2112.05692) +* [CLAY: Learning to Denoise Raw Mobile UI Layouts for Improving Datasets at Scale](https://arxiv.org/abs/2201.04100) +* [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455) +* [Multiview Transformers for Video Recognition](https://arxiv.org/abs/2201.04288) +* [PolyViT: Co-training Vision Transformers on Images, Videos and Audio](https://arxiv.org/abs/2111.12993) +* [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) +* [Learning with Neighbor Consistency for Noisy Labels](https://arxiv.org/abs/2202.02200) +* [Token Turing Machines](https://arxiv.org/pdf/2211.09119.pdf) +* [Vid2Seq: Large-Scale Pretraining of a Visual Language Model for Dense Video Captioning](https://arxiv.org/pdf/2302.14115.pdf) +* [AVATAR: Unconstrained Audiovisual Speech Recognition](https://arxiv.org/abs/2206.07684) +* [Adaptive Computation with Elastic Input Sequence](https://arxiv.org/abs/2301.13195) +* [Location-Aware Self-Supervised Transformers for Semantic Segmentation](https://arxiv.org/abs/2212.02400) +* [How can objects help action recognition?](https://openaccess.thecvf.com/content/CVPR2023/html/Zhou_How_Can_Objects_Help_Action_Recognition_CVPR_2023_paper.html) +* [Verbs in Action: Improving verb understanding in video-language models](https://arxiv.org/abs/2304.06708) +* [Unified Visual Relationship Detection with Vision and Language Models](https://arxiv.org/abs/2303.08998) +* [UnLoc: A Unified Framework for Video Localization Tasks](https://arxiv.org/abs/2308.11062) +* [REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory](https://arxiv.org/abs/2212.05221) +* [Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922) +* [MatFormer: Nested Transformer for Elastic Inference](https://arxiv.org/abs/2310.07707) +* [Pixel Aligned Language Models](https://arxiv.org/abs/2312.09237) +* [A Generative Approach for Wikipedia-Scale Visual Entity Recognition](https://arxiv.org/abs/2403.02041) +* [Streaming Dense Video Captioning](https://arxiv.org/abs/2404.01297) +* [Dense Video Object Captioning from Disjoint Supervision](https://arxiv.org/abs/2306.11729) + +More information can be found in [projects](https://github.com/google-research/scenic/tree/main/scenic/projects#list-of-projects-hosted-in-scenic). + +Baselines that were reproduced in Scenic: + +* [(ViT) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) +* [(DETR) End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) +* [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) +* [(CLIP) Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) +* [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601) +* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) +* [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) +* [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370) +* [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) +* [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) +* [PCT: Point Cloud Transformer](https://arxiv.org/abs/2012.09688) +* [Universal Transformers](https://arxiv.org/abs/1807.03819) +* [PonderNet](https://arxiv.org/abs/2107.05407) +* [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) +* [Rethinking Attention with Performers](https://arxiv.org/abs/2009.14794) +* [(CenterNet) Objects as Points](https://arxiv.org/abs/1904.07850) +* [(SAM) Segment Anything](https://arxiv.org/abs/2304.02643) + + +More information can be found in [baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models). + + +## Philosophy +*Scenic* aims to facilitate rapid prototyping of large-scale vision models. To +keep the code simple to understand and extend we prefer *forking and +copy-pasting over adding complexity or increasing abstraction*. Only when +functionality proves to be widely useful across many models and tasks it may be +upstreamed to Scenic's shared libraries. + + + +## Getting started +* See `projects/baselines/README.md` for a walk-through baseline models and + instructions on how to run the code. +* If you would like to contribute to *Scenic*, please check out the + [Philisophy](#philosophy), [Code structure](#code_structure) and + [Contributing](CONTRIBUTING.md) sections. + Should your contribution be a part of the shared libraries, please send us a + pull request! + + +### Quickstart +You will need Python 3.9 or later. Download the code from GitHub + +```shell +$ git clone https://github.com/google-research/scenic.git +$ cd scenic +$ pip install . +``` + +and run training for ViT on ImageNet: + +```shell +$ python scenic/main.py -- \ + --config=scenic/projects/baselines/configs/imagenet/imagenet_vit_config.py \ + --workdir=./ +``` + +Note that for specific projects and baselines, you might need to install extra +packages that are mentioned in their `README.md` or `requirements.txt` files. + +[Here](https://colab.research.google.com/github/google-research/scenic/blob/main/scenic/common_lib/colabs/scenic_playground.ipynb) +is also a minimal colab to train a simple feed-forward model using Scenic. + + +## Scenic component design +Scenic is designed to propose different levels of abstraction, to support +hosting projects that only require changing hyper-parameters by defining config +files, to those that need customization on the input pipeline, model +architecture, losses and metrics, and the training loop. To make this happen, +the code in Scenic is organized as either _project-level_ code, +which refers to customized code for specific projects or baselines or +_library-level_ code, which refers to common functionalities and general +patterns that are adapted by the majority of projects. The project-level +code lives in the `projects` directory. + +
+scenic design +
+ +### Library-level code +The goal is to keep the library-level code minimal and well-tested and to avoid +introducing extra abstractions to support minor use-cases. Shared libraries +provided by *Scenic* are split into: + +* `dataset_lib`: Implements IO pipelines for loading and pre-processing data + for common Computer Vision tasks and benchmarks (see "Tasks and Datasets" + section). All pipelines are designed to be scalable and support multi-host + and multi-device setups, taking care dividing data among multiple hosts, + incomplete batches, caching, pre-fetching, etc. +* `model_lib` : Provides + * several abstract model interfaces (e.g. `ClassificationModel` or + `SegmentationModel` in `model_lib.base_models`) with task-specific + losses and metrics; + * neural network layers in `model_lib.layers`, focusing on efficient + implementation of attention and transformer layers; + * accelerator-friendly implementations of bipartite matching + algorithms in `model_lib.matchers`. +* `train_lib`: Provides tools for constructing training loops and implements + several optimized trainers (classification trainer and segmentation trainer) + that can be forked for customization. +* `common_lib`: General utilities, like logging and debugging modules, + functionalities for processing raw data, etc. + +### Project-level code +Scenic supports the development of customized solutions for customized tasks and +data via the concept of "project". There is no one-fits-all recipe for how much +code should be re-used by a project. Projects can consist of only configs and +use the common models, trainers, task/data that live in library-level code, or +they can simply fork any of the mentioned functionalities and redefine, layers, +losses, metrics, logging methods, tasks, architectures, as well as training and +evaluation loops. The modularity of library-level code makes it flexible for +projects to fall placed on any spot in the "run-as-is" to "fully customized" +spectrum. + +Common baselines such as a ResNet and Vision Transformer (ViT) are implemented +in the [`projects/baselines`](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines) +project. Forking models in this directory is a good starting point for new +projects. + + +## Citing Scenic +If you use Scenic, you can cite our [white paper](https://openaccess.thecvf.com/content/CVPR2022/html/Dehghani_Scenic_A_JAX_Library_for_Computer_Vision_Research_and_Beyond_CVPR_2022_paper.html). +Here is an example BibTeX entry: + +```bibtex +@InProceedings{dehghani2021scenic, + author = {Dehghani, Mostafa and Gritsenko, Alexey and Arnab, Anurag and Minderer, Matthias and Tay, Yi}, + title = {Scenic: A JAX Library for Computer Vision Research and Beyond}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2022}, + pages = {21393-21398} +} +``` + +_Disclaimer: This is not an official Google product._ diff --git a/__pycache__/owlv2_helper.cpython-310.pyc b/__pycache__/owlv2_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b44736907c28f750366c6e79ae285077dcbf117a Binary files /dev/null and b/__pycache__/owlv2_helper.cpython-310.pyc differ diff --git a/__pycache__/owlv2_helper_functions.cpython-310.pyc b/__pycache__/owlv2_helper_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99000791a4d0f83562b1cdba662dcbaac69325ca Binary files /dev/null and b/__pycache__/owlv2_helper_functions.cpython-310.pyc differ diff --git a/auto_bbox.py b/auto_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceb1bd8dc03c5d3405e4086bd6c06dc498e09ba --- /dev/null +++ b/auto_bbox.py @@ -0,0 +1,266 @@ +import os +import sys +import cv2 +import json +import glob +import argparse +import subprocess +from typing import List, Tuple, Dict, Any + +import numpy as np +from tqdm import tqdm + + +# ----------------- Args ----------------- +def parse_args(): + ap = argparse.ArgumentParser("OWLv2 detection on JPG folders (Top-K per image), multi-GPU.") + ap.add_argument("--input_dir", type=str, required=True, help="Root that contains subfolders of JPGs; if JPGs are directly under input_dir, it will be treated as a single set.") + ap.add_argument("--startswith", type=str, default="", help="Filter folder name prefix (or input_dir basename if no subfolders).") + ap.add_argument("--output_dir", type=str, required=True) + ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N-th image within a folder.") + ap.add_argument("--top_k", type=int, default=5) + ap.add_argument("--max_frames", type=int, default=0, help="Max processed images per folder; 0 means no limit.") + ap.add_argument("--num_workers", type=int, default=1, help="#GPUs/#workers") + ap.add_argument("--worker_idx", type=int, default=-1, help="internal; >=0 means child worker") + ap.add_argument("--shard_file", type=str, default="", help="internal; JSON with folder paths for this worker") + ap.add_argument("--scenic_root", type=str, default="/home/ubuntu/rs/JiT/VisionModels/Scenic_OWLv2/big_vision") + return ap.parse_args() + + +# ----------------- Utils ----------------- +def _has_jpgs(path: str) -> bool: + exts = ("*.jpg", "*.jpeg", "*.JPG", "*.JPEG") + for pat in exts: + if glob.glob(os.path.join(path, pat)): + return True + return False + + +def iter_image_dirs(input_dir: str, startswith: str) -> List[str]: + """ + Returns a list of directories to process. + - If input_dir contains subfolders: return subfolders that contain JPGs and match startswith. + - Else if input_dir itself contains JPGs and its basename matches startswith: return [input_dir]. + """ + input_dir = os.path.abspath(input_dir) + subs = sorted([p for p in glob.glob(os.path.join(input_dir, "*")) if os.path.isdir(p)]) + # Prefer subfolders if any exist and contain jpgs + dirs = [d for d in subs if os.path.basename(d).startswith(startswith) and _has_jpgs(d)] + if dirs: + return dirs + + # Fallback: treat input_dir itself as one set if it has jpgs + base_ok = os.path.basename(os.path.normpath(input_dir)).startswith(startswith) + if base_ok and _has_jpgs(input_dir): + return [input_dir] + return [] + + +def ensure_dir(p: str): + os.makedirs(p, exist_ok=True) + + +def draw_single_box(frame_bgr: np.ndarray, box: List[float], color=(0, 255, 0), thickness=2) -> np.ndarray: + x1, y1, x2, y2 = map(int, box) + out = frame_bgr.copy() + cv2.rectangle(out, (x1, y1), (x2, y2), color, thickness) + return out + + +def list_images_sorted(folder: str) -> List[str]: + pats = ["*.jpg", "*.jpeg", "*.JPG", "*.JPEG"] + files = [] + for pat in pats: + files.extend(glob.glob(os.path.join(folder, pat))) + # Sort by natural file name order + return sorted(files) + + +# ----------------- Worker logic (imports JAX/Scenic inside) ----------------- +def worker_run(args, dir_paths: List[str]): + import sys as _sys + if args.scenic_root not in _sys.path: + _sys.path.append(args.scenic_root) + + # Free TF GPU to JAX in this process (why: avoid TF reserving VRAM) + import tensorflow as tf + tf.config.experimental.set_visible_devices([], "GPU") + + from scenic.projects.owl_vit import configs + from scenic.projects.owl_vit import models + import jax + import functools + import owlv2_helper as helper # must be available in PYTHONPATH + + class OWLv2Objectness: + def __init__(self, top_k: int = 5): + self.top_k = top_k + self.config = configs.owl_v2_clip_b16.get_config(init_mode="canonical_checkpoint") + self.module = models.TextZeroShotDetectionModule( + body_configs=self.config.model.body, + objectness_head_configs=self.config.model.objectness_head, + normalize=self.config.model.normalize, + box_bias=self.config.model.box_bias, + ) + self.variables = self.module.load_variables(self.config.init_from.checkpoint_path) + + self.image_embedder = jax.jit( + functools.partial(self.module.apply, self.variables, train=False, method=self.module.image_embedder) + ) + self.objectness_predictor = jax.jit( + functools.partial(self.module.apply, self.variables, method=self.module.objectness_predictor) + ) + self.box_predictor = jax.jit( + functools.partial(self.module.apply, self.variables, method=self.module.box_predictor) + ) + + def detect(self, image_bgr: np.ndarray) -> List[Tuple[List[float], float]]: + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + processed = helper.preprocess_images([image_rgb], self.config.dataset_configs.input_size)[0] + feature_map = self.image_embedder(processed[None, ...]) + b, h, w, d = feature_map.shape + image_features = feature_map.reshape(b, h * w, d) + + obj_logits = self.objectness_predictor(image_features)["objectness_logits"] + raw_boxes = self.box_predictor(image_features=image_features, feature_map=feature_map)["pred_boxes"] + + obj = np.array(obj_logits[0], dtype=np.float32) + raw_boxes = np.array(raw_boxes[0], dtype=np.float32) + boxes = helper.rescale_detection_box(raw_boxes, image_rgb) + + if len(obj) == 0: + return [] + + k = min(self.top_k, len(obj)) + thresh = np.partition(obj, -k)[-k] + + filtered: List[Tuple[List[float], float]] = [] + H, W = image_rgb.shape[:2] + for box, score in zip(boxes, obj): + if score < thresh: + continue + if helper.too_small(box) or helper.too_large(box, image_rgb): + continue + x1, y1, x2, y2 = box + x1 = max(0, min(float(x1), W - 1)) + y1 = max(0, min(float(y1), H - 1)) + x2 = max(0, min(float(x2), W - 1)) + y2 = max(0, min(float(y2), H - 1)) + filtered.append(([x1, y1, x2, y2], float(score))) + + kept_boxes = helper.remove_overlapping_bboxes([b for b, _ in filtered]) + + def _match_score(bb: List[float]) -> float: + arr = np.array([b for b, _ in filtered], dtype=np.float32) + idx = int(np.argmin(np.abs(arr - np.array(bb, dtype=np.float32)).sum(axis=1))) + return filtered[idx][1] + + return [(bb, _match_score(bb)) for bb in kept_boxes] + + detector = OWLv2Objectness(top_k=args.top_k) + + for dpath in tqdm(dir_paths, desc=f"Worker{args.worker_idx}", unit="set"): + stem = os.path.basename(os.path.normpath(dpath)) + images = list_images_sorted(dpath) + if not images: + print(f"[WARN][w{args.worker_idx}] No JPGs under: {dpath}") + continue + + saved_cnt = 0 + pbar = tqdm(total=len(images), desc=f"{stem}[w{args.worker_idx}]", unit="img", leave=False) + + for idx, ipath in enumerate(images): + pbar.update(1) + if args.frame_stride > 1 and (idx % args.frame_stride) != 0: + continue + + frame = cv2.imread(ipath, cv2.IMREAD_COLOR) + if frame is None: + print(f"[WARN][w{args.worker_idx}] Cannot read: {ipath}") + continue + + boxes_scores = detector.detect(frame) + if boxes_scores: + boxes_scores = sorted(boxes_scores, key=lambda x: x[1], reverse=True)[:args.top_k] + + fname = os.path.basename(ipath) + for i, (box, score) in enumerate(boxes_scores): + out_dir = os.path.join(args.output_dir, stem, f"object_{i}") + ensure_dir(out_dir) + vis = draw_single_box(frame, box, color=(0, 255, 0), thickness=2) + cv2.imwrite(os.path.join(out_dir, fname), vis) + + saved_cnt += 1 + if args.max_frames and saved_cnt >= args.max_frames: + break + + pbar.close() + + +# ----------------- Master ----------------- +def main(): + args = parse_args() + + # Child worker path + if args.worker_idx >= 0: + if not args.shard_file or not os.path.exists(args.shard_file): + raise RuntimeError("Worker requires --shard_file with JSON list of folder paths.") + with open(args.shard_file, "r", encoding="utf-8") as f: + dir_paths = json.load(f) + worker_run(args, dir_paths) + return + + # Master path + dir_paths = iter_image_dirs(args.input_dir, args.startswith) + if not dir_paths: + print(f"[INFO] No JPG folders (or JPGs) startwith '{args.startswith}' under {args.input_dir}") + return + + num_workers = max(1, int(args.num_workers)) + shards: List[List[str]] = [[] for _ in range(num_workers)] + for i, d in enumerate(dir_paths): + shards[i % num_workers].append(d) + + procs = [] + tmp_dir = os.path.join(args.output_dir, "_shards_tmp") + ensure_dir(tmp_dir) + + for w in range(num_workers): + shard_path = os.path.join(tmp_dir, f"shard_{w}.json") + with open(shard_path, "w", encoding="utf-8") as f: + json.dump(shards[w], f, ensure_ascii=False, indent=0) + + # Bind GPU: cycle through available GPU ids [0..num_workers-1] + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(w) # one GPU per worker + + cmd = [ + sys.executable, __file__, + "--input_dir", args.input_dir, + "--startswith", args.startswith, + "--output_dir", args.output_dir, + "--frame_stride", str(args.frame_stride), + "--top_k", str(args.top_k), + "--max_frames", str(args.max_frames), + "--num_workers", str(num_workers), + "--worker_idx", str(w), + "--shard_file", shard_path, + "--scenic_root", args.scenic_root, + ] + print(f"[Master] Launch worker {w}: GPU={env['CUDA_VISIBLE_DEVICES']} folders={len(shards[w])}") + procs.append(subprocess.Popen(cmd, env=env)) + + # wait + rc = 0 + for p in procs: + p.wait() + rc |= p.returncode + + if rc != 0: + print("[Master] Some workers failed. Return code:", rc) + else: + print("[Master] All workers done. Output:", args.output_dir) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/big_vision/.gitignore b/big_vision/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ed8ebf583f771da9150c35db3955987b7d757904 --- /dev/null +++ b/big_vision/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/big_vision/CONTRIBUTING.md b/big_vision/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..5e5644093c15ecd61b0d0308990855dcbb320a2e --- /dev/null +++ b/big_vision/CONTRIBUTING.md @@ -0,0 +1,26 @@ +# How to Contribute + +At this time we do not plan to accept non-trivial contributions. The main +purpose of this codebase is to allow the community to reproduce results from our +publications. + +You are however free to start a fork of the project for your purposes as +permitted by the license. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement (CLA). You (or your employer) retain the copyright to your +contribution; this simply gives us permission to use and redistribute your +contributions as part of the project. Head over to + to see your current agreements on file or +to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). diff --git a/big_vision/LICENSE b/big_vision/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/big_vision/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/big_vision/README.md b/big_vision/README.md new file mode 100644 index 0000000000000000000000000000000000000000..289fd7401a5ad6be14b6f7f8d8cf3b157a6dfc9f --- /dev/null +++ b/big_vision/README.md @@ -0,0 +1,499 @@ +# Big Vision + +This codebase is designed for training large-scale vision models using +[Cloud TPU VMs](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) +or GPU machines. It is based on [Jax](https://github.com/google/jax)/[Flax](https://github.com/google/flax) +libraries, and uses [tf.data](https://www.tensorflow.org/guide/data) and +[TensorFlow Datasets](https://www.tensorflow.org/datasets) for scalable and +reproducible input pipelines. + +The open-sourcing of this codebase has two main purposes: +1. Publishing the code of research projects developed in this codebase (see a + list below). +2. Providing a strong starting point for running large-scale vision experiments + on GPU machines and Google Cloud TPUs, which should scale seamlessly and + out-of-the box from a single TPU core to a distributed setup with up to 2048 + TPU cores. + +`big_vision` aims to support research projects at Google. We are unlikely to +work on feature requests or accept external contributions, unless they were +pre-approved (ask in an issue first). For a well-supported transfer-only +codebase, see also [vision_transformer](https://github.com/google-research/vision_transformer). + +Note that `big_vision` is quite dynamic codebase and, while we intend to keep +the core code fully-functional at all times, we can not guarantee timely updates +of the project-specific code that lives in the `.../proj/...` subfolders. +However, we provide a [table](#project-specific-commits) with last known +commits where specific projects were known to work. + +The following research projects were originally conducted in the `big_vision` +codebase: + +### Architecture research + +- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929), by + Alexey Dosovitskiy*, Lucas Beyer*, Alexander Kolesnikov*, Dirk Weissenborn*, + Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, + Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby* +- [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560), by + Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*\ + Resources: [config](big_vision/configs/proj/scaling_laws/train_vit_g.py). +- [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270), by + Andreas Steiner*, Alexander Kolesnikov*, Xiaohua Zhai*, Ross Wightman, + Jakob Uszkoreit, and Lucas Beyer* +- [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601), by + Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*, + Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, + Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy\ + Resources: [config](big_vision/configs/mlp_mixer_i1k.py). +- [Better plain ViT baselines for ImageNet-1k](https://arxiv.org/abs/2205.01580), by + Lucas Beyer, Xiaohua Zhai, Alexander Kolesnikov\ + Resources: [config](big_vision/configs/vit_s16_i1k.py) +- [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by + Alexander Kolesnikov^*, André Susano Pinto^*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*\ + Resources: [readme](big_vision/configs/proj/uvim/README.md), [configs](big_vision/configs/proj/uvim), [colabs](big_vision/configs/proj/uvim). +- [FlexiViT: One Model for All Patch Sizes](https://arxiv.org/abs/2212.08013), by + Lucas Beyer*, Pavel Izmailov*, Alexander Kolesnikov*, Mathilde Caron*, Simon + Kornblith*, Xiaohua Zhai*, Matthias Minderer*, Michael Tschannen*, Ibrahim + Alabdulmohsin*, Filip Pavetic*\ + Resources: [readme](big_vision/configs/proj/flexivit/README.md), [configs](big_vision/configs/proj/flexivit). +- [Dual PatchNorm](https://arxiv.org/abs/2302.01327), by Manoj Kumar, Mostafa Dehghani, Neil Houlsby. +- [Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design](https://arxiv.org/abs/2305.13035), by + Ibrahim Alabdulmohsin*, Xiaohua Zhai*, Alexander Kolesnikov, Lucas Beyer*. +- (partial) [Scaling Vision Transformers to 22 Billion Parameters](https://arxiv.org/abs/2302.05442), by + Mostafa Dehghani*, Josip Djolonga*, Basil Mustafa*, Piotr Padlewski*, Jonathan Heek*, *wow many middle authors*, Neil Houlsby*. +- (partial) [Finite Scalar Quantization: VQ-VAE Made Simple](https://arxiv.org/abs/2309.15505), by + Fabian Mentzer, David Minnen, Eirikur Agustsson, Michael Tschannen. +- [GIVT: Generative Infinite-Vocabulary Transformers](https://arxiv.org/abs/2312.02116), by + Michael Tschannen, Cian Eastwood, Fabian Mentzer.\ + Resources: [readme](big_vision/configs/proj/givt/README.md), [config](big_vision/configs/proj/givt/givt_imagenet2012.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/givt/givt_demo_colab.ipynb). +- [Unified Auto-Encoding with Masked Diffusion](https://arxiv.org/abs/2406.17688), by + Philippe Hansen-Estruch, Sriram Vishwanath, Amy Zhang, Manan Tomar. + + +### Multimodal research + +- [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by + Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers, + Alexander Kolesnikov, and Lucas Beyer*\ + Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb). +- [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045), by + Michael Tschannen, Basil Mustafa, Neil Houlsby\ + Resources: [readme](big_vision/configs/proj/clippo/README.md), [config](big_vision/configs/proj/clippo/train_clippo.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/clippo/clippo_colab.ipynb). +- [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343), by + Xiaohua Zhai*, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer*\ + Resources: [colab and models](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb), code TODO. +- [A Study of Autoregressive Decoders for Multi-Tasking in Computer Vision](https://arxiv.org/abs/2303.17376), by + Lucas Beyer*, Bo Wan*, Gagan Madan*, Filip Pavetic*, Andreas Steiner*, Alexander Kolesnikov, André Susano Pinto, Emanuele Bugliarello, Xiao Wang, Qihang Yu, Liang-Chieh Chen, Xiaohua Zhai*. +- [Image Captioners Are Scalable Vision Learners Too](https://arxiv.org/abs/2306.07915), by + Michael Tschannen*, Manoj Kumar*, Andreas Steiner*, Xiaohua Zhai, Neil Houlsby, Lucas Beyer*.\ + Resources: [readme](big_vision/configs/proj/cappa/README.md), [config](big_vision/configs/proj/cappa/pretrain.py), [model](big_vision/models/proj/cappa/cappa.py). +- [Three Towers: Flexible Contrastive Learning with Pretrained Image Models](https://arxiv.org/abs/2305.16999), by Jannik Kossen, Mark Collier, Basil Mustafa, Xiao Wang, Xiaohua Zhai, Lucas Beyer, Andreas Steiner, Jesse Berent, Rodolphe Jenatton, Efi Kokiopoulou. +- (partial) [PaLI: A Jointly-Scaled Multilingual Language-Image Model](https://arxiv.org/abs/2209.06794), by Xi Chen, Xiao Wang, Soravit Changpinyo, *wow so many middle authors*, Anelia Angelova, Xiaohua Zhai, Neil Houlsby, Radu Soricut. +- (partial) [PaLI-3 Vision Language Models: Smaller, Faster, Stronger](https://arxiv.org/abs/2310.09199), by Xi Chen, Xiao Wang, Lucas Beyer, Alexander Kolesnikov, Jialin Wu, Paul Voigtlaender, Basil Mustafa, Sebastian Goodman, Ibrahim Alabdulmohsin, Piotr Padlewski, Daniel Salz, Xi Xiong, Daniel Vlasic, Filip Pavetic, Keran Rong, Tianli Yu, Daniel Keysers, Xiaohua Zhai, Radu Soricut. +- [LocCa](https://arxiv.org/abs/2403.19596), by + Bo Wan, Michael Tschannen, Yongqin Xian, Filip Pavetic, Ibrahim Alabdulmohsin, Xiao Wang, André Susano Pinto, Andreas Steiner, Lucas Beyer, Xiaohua Zhai. +- [PaliGemma](https://arxiv.org/abs/2407.07726), + [PaliGemma 2](https://arxiv.org/abs/2412.03555), by *wow many authors*.\ +- Resources: [readme](big_vision/configs/proj/paligemma/README.md), + [model](big_vision/models/proj/paligemma/paligemma.py), + [transfer configs](big_vision/configs/proj/paligemma/transfers), + [datasets](big_vision/datasets), + [CountBenchQA](big_vision/datasets/countbenchqa/data/countbench_paired_questions.json). + +### Training + +- [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by + Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil, + and Alexander Kolesnikov*\ + Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing). +- [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412), by + Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur +- [Surrogate Gap Minimization Improves Sharpness-Aware Training](https://arxiv.org/abs/2203.08065), by Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan and Ting Liu \ + Resources: [trainer](big_vision/trainers/proj/gsam/gsam.py), [config](big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py) [reproduced results](https://github.com/google-research/big_vision/pull/8#pullrequestreview-1078557411) +- [Tuning computer vision models with task rewards](https://arxiv.org/abs/2302.08242), by + André Susano Pinto*, Alexander Kolesnikov*, Yuge Shi, Lucas Beyer, Xiaohua Zhai. +- (partial) [VeLO: Training Versatile Learned Optimizers by Scaling Up](https://arxiv.org/abs/2211.09760) by + Luke Metz, James Harrison, C. Daniel Freeman, Amil Merchant, Lucas Beyer, James Bradbury, Naman Agrawal, Ben Poole, Igor Mordatch, Adam Roberts, Jascha Sohl-Dickstein. + +### Misc + +- [Are we done with ImageNet?](https://arxiv.org/abs/2006.07159), by + Lucas Beyer*, Olivier J. Hénaff*, Alexander Kolesnikov*, Xiaohua Zhai*, Aäron van den Oord*. +- [No Filter: Cultural and Socioeconomic Diversity in Contrastive Vision-Language Models](https://arxiv.org/abs/2405.13777), by + Angéline Pouget, Lucas Beyer, Emanuele Bugliarello, Xiao Wang, Andreas Peter Steiner, Xiaohua Zhai, Ibrahim Alabdulmohsin. + +# Codebase high-level organization and principles in a nutshell + +The main entry point is a trainer module, which typically does all the +boilerplate related to creating a model and an optimizer, loading the data, +checkpointing and training/evaluating the model inside a loop. We provide the +canonical trainer `train.py` in the root folder. Normally, individual projects +within `big_vision` fork and customize this trainer. + +All models, evaluators and preprocessing operations live in the corresponding +subdirectories and can often be reused between different projects. We encourage +compatible APIs within these directories to facilitate reusability, but it is +not strictly enforced, as individual projects may need to introduce their custom +APIs. + +We have a powerful configuration system, with the configs living in the +`configs/` directory. Custom trainers and modules can directly extend/modify +the configuration options. + +Project-specific code resides in the `.../proj/...` namespace. It is not always +possible to keep project-specific in sync with the core `big_vision` libraries, +Below we provide the [last known commit](#project-specific-commits) +for each project where the project code is expected to work. + +Training jobs are robust to interruptions and will resume seamlessly from the +last saved checkpoint (assuming a user provides the correct `--workdir` path). + +Each configuration file contains a comment at the top with a `COMMAND` snippet +to run it, and some hint of expected runtime and results. See below for more +details, but generally speaking, running on a GPU machine involves calling +`python -m COMMAND` while running on TPUs, including multi-host, involves + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all + --command "bash big_vision/run_tpu.sh COMMAND" +``` + +See instructions below for more details on how to run `big_vision` code on a +GPU machine or Google Cloud TPU. + +By default we write checkpoints and logfiles. The logfiles are a list of JSON +objects, and we provide a short and straightforward [example colab to read +and display the logs and checkpoints](https://colab.research.google.com/drive/1R_lvV542WUp8Q2y8sbyooZOGCplkn7KI?usp=sharing). + +# Current and future contents + +The first release contains the core part of pre-training, transferring, and +evaluating classification models at scale on Cloud TPU VMs. + +We have since added the following key features and projects: +- Contrastive Image-Text model training and evaluation as in LiT and CLIP. +- Patient and consistent distillation. +- Scaling ViT. +- MLP-Mixer. +- UViM. + +Features and projects we plan to release in the near future, in no particular +order: +- ImageNet-21k in TFDS. +- Loading misc public models used in our publications (NFNet, MoCov3, DINO). +- Memory-efficient Polyak-averaging implementation. +- Advanced JAX compute and memory profiling. We are using internal tools for + this, but may eventually add support for the publicly available ones. + +We will continue releasing code of our future publications developed within +`big_vision` here. + +### Non-content + +The following exist in the internal variant of this codebase, and there is no +plan for their release: +- Regular regression tests for both quality and speed. They rely heavily on + internal infrastructure. +- Advanced logging, monitoring, and plotting of experiments. This also relies + heavily on internal infrastructure. However, we are open to ideas on this + and may add some in the future, especially if implemented in a + self-contained manner. +- Not yet published, ongoing research projects. + + +# GPU Setup + +We first discuss how to setup and run `big_vision` on a (local) GPU machine, +and then discuss the setup for Cloud TPUs. Note that data preparation step for +(local) GPU setup can be largely reused for the Cloud TPU setup. While the +instructions skip this for brevity, we highly recommend using a +[virtual environment](https://docs.python.org/3/library/venv.html) when +installing python dependencies. + +## Setting up python packages + +The first step is to checkout `big_vision` and install relevant python +dependencies: + +``` +git clone https://github.com/google-research/big_vision +cd big_vision/ +pip3 install --upgrade pip +pip3 install -r big_vision/requirements.txt +``` + +The latest version of `jax` library can be fetched as + +``` +pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +You may need a different `jax` package, depending on CUDA and cuDNN libraries +installed on your machine. Please consult +[official jax documentation](https://github.com/google/jax#pip-installation-gpu-cuda) +for more information. + +## Preparing tfds data + +For unified and reproducible access to standard datasets we opted to use the +`tensorflow_datasets` (`tfds`) library. It requires each dataset to be +downloaded, preprocessed and then to be stored on a hard drive (or, if you use +"Google Cloud", preferably stored in a "GCP bucket".). + +Many datasets can be downloaded and preprocessed automatically when used +for the first time. Nevertheless, we intentionally disable this feature and +recommend doing dataset preparation step separately, ahead of the first run. It +will make debugging easier if problems arise and some datasets, like +`imagenet2012`, require manually downloaded data. + +Most of the datasets, e.g. `cifar100`, `oxford_iiit_pet` or `imagenet_v2` +can be fully automatically downloaded and prepared by running + +``` +cd big_vision/ +python3 -m big_vision.tools.download_tfds_datasets cifar100 oxford_iiit_pet imagenet_v2 +``` + +A full list of datasets is available at [this link](https://www.tensorflow.org/datasets/catalog/overview#all_datasets). + +Some datasets, like `imagenet2012` or `imagenet2012_real`, require the data to +be downloaded manually and placed into `$TFDS_DATA_DIR/downloads/manual/`, +which defaults to `~/tensorflow_datasets/downloads/manual/`. For example, for +`imagenet2012` and `imagenet2012_real` one needs to place the official +`ILSVRC2012_img_train.tar` and `ILSVRC2012_img_val.tar` files in that directory +and then run +`python3 -m big_vision.tools.download_tfds_datasets imagenet2012 imagenet2012_real` +(which may take ~1 hour). + +If you use `Google Cloud` and, TPUs in particular, you can then upload +the preprocessed data (stored in `$TFDS_DATA_DIR`) to +"Google Cloud Bucket" and use the bucket on any of your (TPU) virtual +machines to access the data. + +## Running on a GPU machine + +Finally, after installing all python dependencies and preparing `tfds` data, +the user can run the job using config of their choice, e.g. to train `ViT-S/16` +model on ImageNet data, one should run the following command: + +``` +python3 -m big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir workdirs/`date '+%m-%d_%H%M'` +``` + +or to train MLP-Mixer-B/16, run (note the `gpu8` config param that reduces the default batch size and epoch count): + +``` +python3 -m big_vision.train --config big_vision/configs/mlp_mixer_i1k.py:gpu8 --workdir workdirs/`date '+%m-%d_%H%M'` +``` + +# Cloud TPU VM setup + +## Create TPU VMs + +To create a single machine with 8 TPU cores, follow the following Cloud TPU JAX +document: +https://cloud.google.com/tpu/docs/run-calculation-jax + +To support large-scale vision research, more cores with multiple hosts are +recommended. Below we provide instructions on how to do it. + +First, create some useful variables, which we be reused: + +``` +export NAME= +export ZONE= +export GS_BUCKET_NAME= +``` + +The following command line will create TPU VMs with 32 cores, +4 hosts. + +``` +gcloud compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version tpu-ubuntu2204-base +``` + +## Install `big_vision` on TPU VMs + +Fetch the `big_vision` repository, copy it to all TPU VM hosts, and install +dependencies. + +``` +git clone https://github.com/google-research/big_vision +gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh" +``` + +## Download and prepare TFDS datasets + +We recommend preparing `tfds` data locally as described above and then uploading +the data to `Google Cloud` bucket. However, if you prefer, the datasets which +do not require manual downloads can be prepared automatically using a TPU +machine as described below. Note that TPU machines have only 100 GB of disk +space, and multihost TPU slices do not allow for external disks to be attached +in a write mode, so the instructions below may not work for preparing large +datasets. As yet another alternative, we provide instructions +[on how to prepare `tfds` data on CPU-only GCP machine](#preparing-tfds-data-on-a-standalone-gcp-cpu-machine). + +Specifically, the seven TFDS datasets used during evaluations will be generated +under `~/tensorflow_datasets` on TPU machine with this command: + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "TFDS_DATA_DIR=~/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced" +``` + +You can then copy the datasets to GS bucket, to make them accessible to all TPU workers. + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME" +``` + +If you want to integrate other public or custom datasets, i.e. imagenet2012, +please follow [the official guideline](https://www.tensorflow.org/datasets/catalog/overview). + +## Pre-trained models + +For the full list of pre-trained models check out the `load` function defined in +the same module as the model code. And for example config on how to use these +models, see `configs/transfer.py`. + +## Run the transfer script on TPU VMs + +The following command line fine-tunes a pre-trained `vit-i21k-augreg-b/32` model +on `cifar10` dataset. + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03" +``` + +## Run the train script on TPU VMs + +To train your own big_vision models on a large dataset, +e.g. `imagenet2012` ([prepare the TFDS dataset](https://www.tensorflow.org/datasets/catalog/imagenet2012)), +run the following command line. + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`" +``` + +## FSDP training. + +`big_vision` supports flexible parameter and model sharding strategies. +Currently, we support a popular FSDP sharding via a simple config change, see [this config example](big_vision/configs/transfer.py). +For example, to run FSDP finetuning of a pretrained ViT-L model, run the following command (possible adjusting batch size depending on your hardware): + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-l/16,dataset=oxford_iiit_pet,crop=resmall_crop,fsdp=True,batch_size=256 --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03" +``` + +## Image-text training with SigLIP. + +A minimal example that uses public `coco` captions data: + +``` +gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.siglip --config big_vision/configs/proj/image_text/siglip_lit_coco.py --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%Y-%m-%d_%H%M'`" +``` + + + +## Sometimes useful gcloud commands + +- Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE` +- Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'` + +## Preparing `tfds` data on a standalone GCP CPU machine. + +First create a new machine and a disk (feel free to adjust exact machine type and disk settings/capacity): + +``` +export NAME_CPU_HOST= +export NAME_DISK= +gcloud compute instances create $NAME_CPU_HOST --machine-type c3-standard-22 --zone $ZONE --image-family ubuntu-2204-lts --image-project ubuntu-os-cloud +gcloud compute disks create $NAME_DISK --size 1000GB --zone $ZONE --type pd-balanced +``` + +Now attach the disk to the newly create machine: + +``` +gcloud compute instances attach-disk $NAME_CPU_HOST --disk $NAME_DISK --zone $ZONE +``` + +Next, `ssh` to the machine `gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE` and +[follow instructions to format and mount the disk](https://cloud.google.com/compute/docs/disks/format-mount-disk-linux). +Let's assume it was mounted to `/mnt/disks/tfds`. + +Almost there, now clone and set up `big_vision`: + +``` +gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "git clone https://github.com/google-research/big_vision.git && cd big_vision && sh big_vision/run_tpu.sh" +``` + +Finally, prepare the dataset (e.g. `coco_captions`) using the utility script and +copy the result to you google cloud bucket: + +``` +gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "cd big_vision && TFDS_DATA_DIR=/mnt/disks/tfds/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets coco_captions" +gcloud compute ssh $NAME_CPU_HOST --zone=$ZONE --command "rm -rf /mnt/disks/tfds/tensorflow_datasets/downloads && gsutil cp -r /mnt/disks/tfds/tensorflow_datasets gs://$GS_BUCKET_NAME" +``` + + +# ViT baseline + +We provide a well-tuned ViT-S/16 baseline in the config file named +`vit_s16_i1k.py`. It achieves 76.5% accuracy on ImageNet validation split in +90 epochs of training, being a strong and simple starting point for research +on the ViT models. + +Please see our [arXiv note](https://arxiv.org/abs/2205.01580) for more details +and if this baseline happens to by useful for your research, consider citing + +``` +@article{vit_baseline, + url = {https://arxiv.org/abs/2205.01580}, + author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander}, + title = {Better plain ViT baselines for ImageNet-1k}, + journal={arXiv preprint arXiv:2205.01580}, + year = {2022}, +} +``` + +# Project specific commits + +The last known commit where the specific project code is expected to work. The +core code and configs are expected to work at head. + +| Project | Commit | +|------------|-----------------------------------------------------------------------------------------------| +| UViM | https://github.com/google-research/big_vision/commit/21bd6ebe253f070f584d8b777ad76f4abce51bef | +| image_text | https://github.com/google-research/big_vision/commit/8921d5141504390a8a4f7b2dacb3b3c042237290 | +| distill | https://github.com/google-research/big_vision/commit/2f3f493af048dbfd97555ff6060f31a0e686f17f | +| GSAM | WIP | +| CLIPPO | https://github.com/google-research/big_vision/commit/fd2d3bd2efc9d89ea959f16cd2f58ae8a495cd44 | +| CapPa | https://github.com/google-research/big_vision/commit/7ace659452dee4b68547575352c022a2eef587a5 | +| GIVT | https://github.com/google-research/big_vision/commit/0cb70881dd33b3343b769347dc19793c4994b8cb | + +# Citing the codebase + +If you found this codebase useful for your research, please consider using +the following BibTEX to cite it: + +``` +@misc{big_vision, + author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander}, + title = {Big Vision}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/google-research/big_vision}} +} +``` + +# Disclaimer + +This is not an official Google Product. + +# License + +Unless explicitly noted otherwise, everything in the big_vision codebase +(including models and colabs) is released under the Apache2 license. +See the LICENSE file for the full license text. diff --git a/big_vision/__init__.py b/big_vision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/__pycache__/__init__.cpython-310.pyc b/big_vision/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd8350cbd6c74dc44e34b6366cf89b4b9706afd1 Binary files /dev/null and b/big_vision/__pycache__/__init__.cpython-310.pyc differ diff --git a/big_vision/__pycache__/utils.cpython-310.pyc b/big_vision/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30f6c73dfa925c37278708b86b5563be17cc5070 Binary files /dev/null and b/big_vision/__pycache__/utils.cpython-310.pyc differ diff --git a/big_vision/configs/__init__.py b/big_vision/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/configs/bit_i1k.py b/big_vision/configs/bit_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd53c318b108bab483923a95e8d3c0df42d709d --- /dev/null +++ b/big_vision/configs/bit_i1k.py @@ -0,0 +1,102 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370 + +Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128: + +big_vision.train \ + --config big_vision/configs/bit_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.model.depth 50 --config.model.width 1 +""" + +# from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(runlocal=False): + """Config for training on ImageNet-1k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 90 + config.num_classes = 1000 + config.loss = 'softmax_xent' + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 # Per host. + + pp_common = '|onehot(1000, key="{lbl}", key_result="labels")' + pp_common += '|value_range(-1, 1)|keep("image", "labels")' + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label') + pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'bit' + config.model = dict( + depth=50, # You can also pass e.g. [3, 5, 10, 2] + width=1.0, + ) + + # Optimizer section + config.optax_name = 'big_vision.momentum_hp' + config.grad_clip_norm = 1.0 + + # linear scaling rule. Don't forget to sweep if sweeping batch_size. + config.wd = (1e-4 / 256) * config.input.batch_size + config.lr = (0.1 / 256) * config.input.batch_size + config.schedule = dict(decay_type='cosine', warmup_steps=1000) + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. + cache='final_data', + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal) + # config.evals.fewshot.log_steps = 1000 + + if runlocal: + config.input.batch_size = 32 + config.input.cache_raw = False + config.input.shuffle_buffer_size = 100 + + local_eval = config.evals.val + config.evals = {'val': local_eval} + config.evals.val.cache = 'none' + + return config \ No newline at end of file diff --git a/big_vision/configs/bit_i21k.py b/big_vision/configs/bit_i21k.py new file mode 100644 index 0000000000000000000000000000000000000000..c42342e9ab8ff513211954efab79dd4309fbe101 --- /dev/null +++ b/big_vision/configs/bit_i21k.py @@ -0,0 +1,85 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for pre-training BiT on ImageNet-21k. + +This config relies on the Imagenet-21k tfds dataset, which is not yet +available publicly in TFDS. We intend to add the dataset to public TFDS soon, +and this config will then be runnable. +""" + +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(): + """Config for training on imagenet-21k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 90 + config.num_classes = 21843 + config.init_head_bias = -10.0 + config.loss = 'sigmoid_xent' + + config.input = dict() + config.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + + pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' + pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') + pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k + pp_eval = 'decode|resize_small(256)|central_crop(224)' + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'bit_paper' + config.model = dict(depth=50, width=1.0) + + # Optimizer section + config.optax_name = 'big_vision.momentum_hp' + config.grad_clip_norm = 1.0 + + # linear scaling rule. Don't forget to sweep if sweeping batch_size. + config.lr = (0.03 / 256) * config.input.batch_size + config.wd = (3e-5 / 256) * config.input.batch_size + config.schedule = dict(decay_type='cosine', warmup_steps=5000) + + # Evaluations on i21k itself. + def eval_i21k(split): + return dict( + type='classification', + data={**config.input.data, 'split': split}, + pp_fn=pp_eval + pp_common_i21k, + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. + ) + config.evals = {} + config.evals.test = eval_i21k('full[:25_600]') + config.evals.val = eval_i21k('full[25_600:51_200]') + config.evals.train = eval_i21k('full[51_200:76_800]') + + # Few-shot evaluators + config.evals.fewshot = get_fewshot_lsr() + config.evals.fewshot.log_steps = 25_000 + + return config \ No newline at end of file diff --git a/big_vision/configs/common.py b/big_vision/configs/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c1628c3ccaa554eb5d2a39e2317fb06953542a6d --- /dev/null +++ b/big_vision/configs/common.py @@ -0,0 +1,188 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A few things commonly used across A LOT of config files.""" + +import string + +import ml_collections as mlc + + +def input_for_quicktest(config_input, quicktest): + if quicktest: + config_input.batch_size = 8 + config_input.shuffle_buffer_size = 10 + config_input.cache_raw = False + + +def parse_arg(arg, lazy=False, **spec): + """Makes ConfigDict's get_config single-string argument more usable. + + Example use in the config file: + + import big_vision.configs.common as bvcc + def get_config(arg): + arg = bvcc.parse_arg(arg, + res=(224, int), + runlocal=False, + schedule='short', + ) + + # ... + + config.shuffle_buffer = 250_000 if not arg.runlocal else 50 + + Ways that values can be passed when launching: + + --config amazing.py:runlocal,schedule=long,res=128 + --config amazing.py:res=128 + --config amazing.py:runlocal # A boolean needs no value for "true". + --config amazing.py:runlocal=False # Explicit false boolean. + --config amazing.py:128 # The first spec entry may be passed unnamed alone. + + Uses strict bool conversion (converting 'True', 'true' to True, and 'False', + 'false', '' to False). + + Args: + arg: the string argument that's passed to get_config. + lazy: allow lazy parsing of arguments, which are not in spec. For these, + the type is auto-extracted in dependence of most complex possible type. + **spec: the name and default values of the expected options. + If the value is a tuple, the value's first element is the default value, + and the second element is a function called to convert the string. + Otherwise the type is automatically extracted from the default value. + + Returns: + ConfigDict object with extracted type-converted values. + """ + # Normalize arg and spec layout. + arg = arg or '' # Normalize None to empty string + spec = {k: get_type_with_default(v) for k, v in spec.items()} + + result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only. + + # Expand convenience-cases for a single parameter without = sign. + if arg and ',' not in arg and '=' not in arg: + # (think :runlocal) If it's the name of sth in the spec (or there is no + # spec), it's that in bool. + if arg in spec or not spec: + arg = f'{arg}=True' + # Otherwise, it is the value for the first entry in the spec. + else: + arg = f'{list(spec.keys())[0]}={arg}' + # Yes, we rely on Py3.7 insertion order! + + # Now, expand the `arg` string into a dict of keys and values: + raw_kv = {raw_arg.split('=')[0]: + raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True' + for raw_arg in arg.split(',') if raw_arg} + + # And go through the spec, using provided or default value for each: + for name, (default, type_fn) in spec.items(): + val = raw_kv.pop(name, None) + result[name] = type_fn(val) if val is not None else default + + if raw_kv: + if lazy: # Process args which are not in spec. + for k, v in raw_kv.items(): + result[k] = autotype(v) + else: + raise ValueError(f'Unhandled config args remain: {raw_kv}') + + return result + + +def get_type_with_default(v): + """Returns (v, string_to_v_type) with lenient bool parsing.""" + # For bool, do safe string conversion. + if isinstance(v, bool): + def strict_bool(x): + assert x.lower() in {'true', 'false', ''} + return x.lower() == 'true' + return (v, strict_bool) + # If already a (default, type) tuple, use that. + if isinstance(v, (tuple, list)): + assert len(v) == 2 and isinstance(v[1], type), ( + 'List or tuple types are currently not supported because we use `,` as' + ' dumb delimiter. Contributions (probably using ast) welcome. You can' + ' unblock by using a string with eval(s.replace(";", ",")) or similar') + return (v[0], v[1]) + # Otherwise, derive the type from the default value. + return (v, type(v)) + + +def autotype(x): + """Auto-converts string to bool/int/float if possible.""" + assert isinstance(x, str) + if x.lower() in {'true', 'false'}: + return x.lower() == 'true' # Returns as bool. + try: + return int(x) # Returns as int. + except ValueError: + try: + return float(x) # Returns as float. + except ValueError: + return x # Returns as str. + + +def pack_arg(**kw): + """Packs key-word args as a string to be parsed by `parse_arg()`.""" + for v in kw.values(): + assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}" + return ','.join([f'{k}={v}' for k, v in kw.items()]) + + +def arg(**kw): + """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.""" + return {'config_arg': pack_arg(**kw), **kw} + + +def _get_field_ref(config_dict, field_name): + path = field_name.split('.') + for field in path[:-1]: + config_dict = getattr(config_dict, field) + return config_dict.get_ref(path[-1]) + + +def format_str(format_string, config): + """Format string with reference fields from config. + + This makes it easy to build preprocess strings that contain references to + fields tha are edited after. E.g.: + + ``` + config = mlc.ConficDict() + config.res = (256, 256) + config.pp = bvcc.format_str('resize({res})', config) + ... + # if config.res is modified (e.g. via sweeps) it will propagate to pp field: + config.res = (512, 512) + assert config.pp == 'resize((512, 512))' + ``` + + Args: + format_string: string to format with references. + config: ConfigDict to get references to format the string. + + Returns: + A reference field which renders a string using references to config fields. + """ + output = '' + parts = string.Formatter().parse(format_string) + for (literal_text, field_name, format_spec, conversion) in parts: + assert not format_spec and not conversion + output += literal_text + if field_name: + output += _get_field_ref(config, field_name).to_str() + return output diff --git a/big_vision/configs/common_fewshot.py b/big_vision/configs/common_fewshot.py new file mode 100644 index 0000000000000000000000000000000000000000..c430383639adcf6103e0976b190eda5b2740321a --- /dev/null +++ b/big_vision/configs/common_fewshot.py @@ -0,0 +1,60 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Most common few-shot eval configuration.""" + +import ml_collections as mlc + + +def get_fewshot_lsr(target_resolution=224, resize_resolution=256, + runlocal=False, pp=None, **kw): + """Returns a standard-ish fewshot eval configuration.""" + kw.setdefault('representation_layer', 'pre_logits') + kw.setdefault('shots', (1, 5, 10, 25)) + kw.setdefault('l2_reg', 2.0 ** 10) + kw.setdefault('num_seeds', 3) + kw.setdefault('prefix', '') # No prefix as we already use a/ z/ and zz/ + + # Backward-compatible default: + if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']): # pylint: disable=line-too-long + kw['log_steps'] = 25_000 + + config = mlc.ConfigDict(kw) + config.type = 'fewshot_lsr' + config.datasets = { + 'caltech': ('caltech101', 'train', 'test'), # copybara:srtip + 'cars': ('cars196:2.1.0', 'train', 'test'), + 'cifar100': ('cifar100', 'train', 'test'), + 'dtd': ('dtd', 'train', 'test'), + # The first 65000 ImageNet samples have at least 30 shots per any class. + # Commented out by default because needs manual download. + # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'), + 'pets': ('oxford_iiit_pet', 'train', 'test'), + 'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'), + } if not runlocal else { + 'pets': ('oxford_iiit_pet', 'train', 'test'), + } + + pp = pp or '|'.join([ + 'decode', + f'resize({resize_resolution})', + f'central_crop({target_resolution})', + 'value_range(-1,1)' + ]) + pp += '|keep("image", "label")' + config.pp_train = pp + config.pp_eval = pp + config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)] + + return config diff --git a/big_vision/configs/load_and_eval.py b/big_vision/configs/load_and_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7e102b0f561f2cc6ec59439f831e9e289488b7b0 --- /dev/null +++ b/big_vision/configs/load_and_eval.py @@ -0,0 +1,143 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: disable=not-writable,attribute-error +# pylint: disable=line-too-long,missing-function-docstring +r"""A config to load and eval key model using the core train.py. + +The runtime varies widely depending on the model, but each one should reproduce +the corresponding paper's numbers. +This configuration makes use of the "arg" to get_config to select which model +to run, so a few examples are given below: + +Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k: + +big_vision.train \ + --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \ + --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50 + +Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper: + +big_vision.train \ + --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \ + --config.model.variant B/32 --config.model_init howto-i21k-B/32 +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr + + +def eval_only(config, batch_size, spec_for_init): + """Set a few configs that turn trainer into (almost) eval-only.""" + config.total_steps = 0 + config.input = {} + config.input.batch_size = batch_size + config.input.data = dict(name='bv:dummy', spec=spec_for_init) + config.optax_name = 'identity' + config.lr = 0.0 + + config.mesh = [('data', -1)] + config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + + return config + + +def get_config(arg=''): + config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4) + + # Make the config eval-only by setting some dummies. + eval_only(config, config.batch_size, spec_for_init=dict( + image=dict(shape=(224, 224, 3), dtype='float32'), + )) + + config.evals = dict(fewshot=get_fewshot_lsr()) + + # Just calls the function with the name given as `config`. + # Could also be a giant if-block if you're into that kind of thing. + globals()[config.name](config) + return config + + +def bit_paper(config): + config.num_classes = 1000 + + config.model_name = 'bit_paper' + config.model_init = 'M-imagenet2012' # M = i21k, -imagenet2012 = fine-tuned + config.model = dict(width=1, depth=50) + + def get_eval(split, lbl, dataset='imagenet2012_real'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. + pp_fn=( + 'decode|resize(384)|value_range(-1, 1)' + f'|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ), + ) + config.evals.test = get_eval('validation', 'original_label') + config.evals.real = get_eval('validation', 'real_label') + config.evals.v2 = get_eval('test', 'label', 'imagenet_v2') + + +def vit_i1k(config): + config.num_classes = 1000 + + config.model_name = 'vit' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d', + rep_size=True) + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet2012', split='validation'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")', + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. + ) + + +def mlp_mixer_i1k(config): + config.num_classes = 1000 + + config.model_name = 'mlp_mixer' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='L/16') + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet2012', split='validation'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")', + loss_name='softmax_xent', + cache='none', # Only run once, on low-mem machine. + ) + + +def vit_i21k(config): + config.num_classes = 21843 + + config.model_name = 'vit' + config.model_init = '' # Will be set in sweep. + config.model = dict(variant='B/32', pool_type='tok') + + config.evals.val = dict( + type='classification', + data=dict(name='imagenet21k', split='full[:51200]'), + pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")', + loss_name='sigmoid_xent', + cache='none', # Only run once, on low-mem machine. + ) diff --git a/big_vision/configs/mlp_mixer_i1k.py b/big_vision/configs/mlp_mixer_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8afe9abfd31f4ecb4e53466ea3e2b2794e8af7e7 --- /dev/null +++ b/big_vision/configs/mlp_mixer_i1k.py @@ -0,0 +1,120 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k"). + +Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128 +with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m. + +big_vision.train \ + --config big_vision/configs/mlp_mixer_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ +""" + +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + + +def get_config(mode=None): + """Config for training Mixer on i1k.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 1000 + config.loss = 'sigmoid_xent' + config.init_head_bias = -6.9 + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 + + config.input.pp = ( + 'decode_jpeg_and_inception_crop(224)' + '|flip_lr' + '|randaug(2,15)' + '|value_range(-1, 1)' + '|onehot(1000, key="label", key_result="labels")' + '|keep("image", "labels")' + ) + pp_eval = ( + 'decode' + '|resize_small(256)|central_crop(224)' + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + + # To continue using the near-defunct randaug op. + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + config.prefetch_to_device = 2 + + # Model section + config.model_name = 'mlp_mixer' + config.model = dict() + config.model.variant = 'B/16' + config.model.stoch_depth = 0.1 + + config.mixup = dict(fold_in=None, p=0.5) + + # Optimizer section + config.optax_name = 'scale_by_adam' + config.grad_clip_norm = 1. + + config.lr = 0.001 + config.wd = 1e-4 + config.schedule = dict( + decay_type='linear', + warmup_steps=10_000, + linear_end=1e-5, + ) + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + cache_final=mode != 'gpu8', + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + config.fewshot = get_fewshot_lsr() + + if mode == 'gpu8': + config.total_epochs = 60 + config.input.batch_size = 512 + config.input.cache_raw = False + if mode == 'regression_test': + config.total_epochs = 60 + + return config diff --git a/big_vision/configs/transfer.py b/big_vision/configs/transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee64e43b274bdf5d462c7e3f3d9b5fc085c1796 --- /dev/null +++ b/big_vision/configs/transfer.py @@ -0,0 +1,186 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long,missing-function-docstring +r"""A config for transferring vit-augreg. + +Best HP selected on (mini)val, expected test results (repeated 5 times): + +ViT-Augreg-B/32: + Dataset, crop, learning rate, mean (%), range (%) + - ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33] + - Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6] + - Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62] + - Pets, inception_crop, 0.003, 93.78, [93.62...94.00] + - Flowers, inception_crop, 0.003, 99.43, [99.42...99.45] + + +Command to run: +big_vision.train \ + --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \ + --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03 +""" + +import big_vision.configs.common as bvcc +import ml_collections as mlc + + +def _set_model(config, model): + """Load pre-trained models: vit or bit.""" + # Reset the head to init (of zeros) when transferring. + config.model_load = dict(dont_load=['head/kernel', 'head/bias']) + + if model == 'vit-i21k-augreg-b/32': + # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270 + config.model_name = 'vit' + config.model_init = 'howto-i21k-B/32' + config.model = dict(variant='B/32', pool_type='tok') + elif model == 'vit-i21k-augreg-l/16': + config.model_name = 'vit' + config.model_init = 'howto-i21k-L/16' + config.model = dict(variant='L/16', pool_type='tok') + elif model == 'vit-s16': + config.model_name = 'vit' + config.model_init = 'i1k-s16-300ep' + config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d', + rep_size=True) + elif model == 'bit-m-r50x1': + config.model_name = 'bit_paper' + config.model_init = 'M' + config.model = dict(depth=50, width=1) + else: + raise ValueError(f'Unknown model: {model}, please define customized model.') + + +def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384): + if dataset == 'cifar10': + _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res) + elif dataset == 'cifar100': + _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res) + elif dataset == 'imagenet2012': + _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res) + _set_imagenet_variants(config) + elif dataset == 'oxford_iiit_pet': + _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res) + elif dataset == 'oxford_flowers102': + _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res) + else: + raise ValueError( + f'Unknown dataset: {dataset}, please define customized dataset.') + + +def _set_task(config, dataset, train, val, test, n_cls, + steps=20_000, warmup=500, lbl='label', crop='resmall_crop', + flip=True, h_res=448, l_res=384): + """Vision task with val and test splits.""" + config.total_steps = steps + config.schedule = dict( + warmup_steps=warmup, + decay_type='cosine', + ) + + config.input.data = dict(name=dataset, split=train) + pp_common = ( + '|value_range(-1, 1)|' + f'onehot({n_cls}, key="{lbl}", key_result="labels")|' + 'keep("image", "labels")' + ) + + if crop == 'inception_crop': + pp_train = f'decode|inception_crop({l_res})' + elif crop == 'resmall_crop': + pp_train = f'decode|resize_small({h_res})|random_crop({l_res})' + elif crop == 'resize_crop': + pp_train = f'decode|resize({h_res})|random_crop({l_res})' + else: + raise ValueError(f'Unknown crop: {crop}. Must be one of: ' + 'inception_crop, resmall_crop, resize_crop') + if flip: + pp_train += '|flip_lr' + config.input.pp = pp_train + pp_common + + pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common + config.num_classes = n_cls + + def get_eval(split): + return dict( + type='classification', + data=dict(name=dataset, split=split), + loss_name='softmax_xent', + log_steps=100, + pp_fn=pp, + ) + config.evals = dict(val=get_eval(val), test=get_eval(test)) + + +def _set_imagenet_variants(config, h_res=448, l_res=384): + """Evaluation tasks on ImageNet variants: v2 and real.""" + pp = (f'decode|resize_small({h_res})|central_crop({l_res})' + '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|' + 'keep("image", "labels")' + ) + + # Special-case rename for i1k (val+test -> minival+val) + config.evals.minival = config.evals.val + config.evals.val = config.evals.test + # NOTE: keep test == val for convenience in subsequent analysis. + + config.evals.real = dict(type='classification') + config.evals.real.data = dict(name='imagenet2012_real', split='validation') + config.evals.real.pp_fn = pp.format(lbl='real_label') + config.evals.real.loss_name = config.loss + config.evals.real.log_steps = 100 + + config.evals.v2 = dict(type='classification') + config.evals.v2.data = dict(name='imagenet_v2', split='test') + config.evals.v2.pp_fn = pp.format(lbl='label') + config.evals.v2.loss_name = config.loss + config.evals.v2.log_steps = 100 + + +def get_config(arg=None): + """Config for adaptation.""" + arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop', + h_res=448, l_res=384, batch_size=512, fsdp=False, + runlocal=False) + config = mlc.ConfigDict() + + config.input = {} + config.input.batch_size = arg.batch_size if not arg.runlocal else 8 + config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100 + + config.log_training_steps = 10 + config.ckpt_steps = 1000 + config.ckpt_timeout = 600 + + # Optimizer section + config.optax_name = 'big_vision.momentum_hp' + config.grad_clip_norm = 1.0 + config.wd = None # That's our default, but just being explicit here! + config.loss = 'softmax_xent' + config.lr = 0.01 + config.mixup = dict(p=0.0) + + config.seed = 0 + + _set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res) + + _set_model(config, arg.model) + if arg.fsdp: + config.mesh = [('data', -1)] + config.sharding_strategy = [('.*', 'fsdp(axis="data")')] + config.sharding_rules = [('act_batch', ('data',))] + config.model.scan = True + + return config \ No newline at end of file diff --git a/big_vision/configs/vit_i1k.py b/big_vision/configs/vit_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6dd8d18ba723175a7a0cf887352e3023a11ccc --- /dev/null +++ b/big_vision/configs/vit_i1k.py @@ -0,0 +1,177 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270 + +This config does NOT include regularization (dropout, stochastic depth), which +was shown to help with B/32, B/16, L/16 models in the paper (Figure 4). + +This configuration makes use of the "arg" to get_config to select which model +to run, so a few examples are given below: + +Run training of a B/16 model: + +big_vision.train \ + --config big_vision/configs/vit_i1k.py:variant=B/16 \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` + +Run training of a B/32 model with custom aug-strenght and 300ep: + +big_vision.train \ + --config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ + --config.total_epochs 300 +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + +MIXUP_DEF = { + 'none': dict(p=0.0, fold_in=None), + 'light1': dict(p=0.0, fold_in=None), + 'light2': dict(p=0.2, fold_in=None), + 'medium1': dict(p=0.2, fold_in=None), + 'medium2': dict(p=0.5, fold_in=None), + 'strong1': dict(p=0.5, fold_in=None), + 'strong2': dict(p=0.8, fold_in=None), +} + +RANDAUG_DEF = { + 'none': '', + 'light1': 'randaug(2,0)', # Actually not nothing! + 'light2': 'randaug(2,10)', + 'medium1': 'randaug(2,15)', + 'medium2': 'randaug(2,15)', + 'strong1': 'randaug(2,20)', + 'strong2': 'randaug(2,20)', +} + + +def get_config(arg=None): + """Config for training.""" + arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='') + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 1000 + config.loss = 'sigmoid_xent' + config.init_head_bias = -6.9 + + # If this gives a KeyError, lookup Fig4 of the paper and add an entry. + # Note, this here is a good average between 30ep and 300ep, sometimes you coud + # find a slightly better setting for either of them. + aug_setting = arg.aug or { + 'Ti/16': 'light1', + 'S/32': 'medium1', + 'S/16': 'medium2', + 'B/32': 'medium2', + 'B/16': 'medium2', + 'L/16': 'medium2', + }[arg.variant] + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache = 'raw_data' if arg.runlocal else 'none' # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 + + pp_common = ( + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + config.input.pp = ( + 'decode_jpeg_and_inception_crop(224)|flip_lr|' + + RANDAUG_DEF[aug_setting] + + pp_common.format(lbl='label') + ) + pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common + + # To continue using the near-defunct randaug op. + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + # Aggressive pre-fetching because our models here are small, so we not only + # can afford it, but we also need it for the smallest models to not be + # bottle-necked by the input pipeline. Play around with it for -L models tho. + config.input.prefetch = 8 + config.prefetch_to_device = 4 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'vit' + config.model = dict( + variant=arg.variant, + rep_size=True, + pool_type='tok', + ) + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560 + # almost always behaves exactly like adam, but at a fraction of the memory + # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a + # good idea to try it when you are memory-bound! + # config.optax_name = 'big_vision.scale_by_adafactor' + # A good flag to play with when hitting instabilities, is the following: + # config.optax = dict(beta2_cap=0.95) + + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = dict(warmup_steps=10_000, decay_type='cosine') + + config.mixup = MIXUP_DEF[aug_setting] + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + cache='final_data' if arg.runlocal else 'none', + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal) + config.fewshot.log_steps = 10_000 + + # Make a few things much smaller for quick local debugging testruns. + if arg.runlocal: + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + config.input.cache_raw = False + config.evals.train.data.split = 'train[:16]' + config.evals.minival.data.split = 'train[:16]' + config.evals.val.data.split = 'validation[:16]' + config.evals.v2.data.split = 'test[:16]' + config.evals.real.data.split = 'validation[:16]' + + return config \ No newline at end of file diff --git a/big_vision/configs/vit_i21k.py b/big_vision/configs/vit_i21k.py new file mode 100644 index 0000000000000000000000000000000000000000..adae41838736be4f4a9737e614152dc5c7fd329b --- /dev/null +++ b/big_vision/configs/vit_i21k.py @@ -0,0 +1,145 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training ViT on ImageNet-21k as in https://arxiv.org/abs/2106.10270 + +This config relies on the Imagenet-21k tfds dataset, which is not yet +available publicly in TFDS. We intend to add the dataset to public TFDS soon, +and this config will then be runnable. + +Note that regularization (dropout, stochastic depth) is not currently +implemented. This was not beneficial for ImageNet-21k pre-trainning. +""" + +import big_vision.configs.common as bvcc +from big_vision.configs.common_fewshot import get_fewshot_lsr +import ml_collections as mlc + +MIXUP_DEF = { + 'none': dict(p=0.0, fold_in=None), + 'light1': dict(p=0.0, fold_in=None), + 'light2': dict(p=0.2, fold_in=None), + 'medium1': dict(p=0.2, fold_in=None), + 'medium2': dict(p=0.5, fold_in=None), + 'strong1': dict(p=0.5, fold_in=None), + 'strong2': dict(p=0.8, fold_in=None), +} + +RANDAUG_DEF = { + 'none': '', + 'light1': 'randaug(2,0)', # Actually not nothing! + 'light2': 'randaug(2,10)', + 'medium1': 'randaug(2,15)', + 'medium2': 'randaug(2,15)', + 'strong1': 'randaug(2,20)', + 'strong2': 'randaug(2,20)', +} + + +def get_config(arg=None): + """Config for training.""" + arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None) + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 21843 + config.init_head_bias = -10.0 + config.loss = 'sigmoid_xent' + + # If this gives a KeyError, lookup Fig4 of the paper and add an entry. + # Note, this here is a good average between 30ep and 300ep, sometimes you coud + # find a slightly better setting for either of them. + aug_setting = { + 'Ti/16': 'none', + 'S/32': 'none', + 'S/16': 'light1', + 'B/32': 'light2', + 'B/16': 'light2', + 'L/16': 'medium2', + }[arg.variant] + + config.input = dict() + config.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + + pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' + pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') + pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') + config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k + pp_eval = 'decode|resize_small(256)|central_crop(224)' + + # To continue using the near-defunct randaug op. + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + # Aggressive pre-fetching because our models here are small, so we not only + # can afford it, but we also need it for the smallest models to not be + # bottle-necked by the input pipeline. Play around with it for -L models tho. + config.input.prefetch = 8 + config.prefetch_to_device = 4 + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'vit' + config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn') + + # Optimizer section + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + config.grad_clip_norm = 1.0 + + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = dict(warmup_steps=10_000, decay_type='cosine') + + config.mixup = MIXUP_DEF[aug_setting] + + # Evaluations on i21k itself. + def eval_i21k(split): + return dict( + type='classification', + data={**config.input.data, 'split': split}, + pp_fn=pp_eval + pp_common_i21k, + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. + ) + config.evals = {} + config.evals.test = eval_i21k('full[:25_600]') + config.evals.val = eval_i21k('full[25_600:51_200]') + config.evals.train = eval_i21k('full[51_200:76_800]') + + # Few-shot evaluators + config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal) + config.evals.fewshot.log_steps = 25_000 + + # Make a few things much smaller for quick local debugging testruns. + if arg.runlocal: + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + config.evals.test.data.split = 'full[:16]' + config.evals.train.data.split = 'full[:16]' + config.evals.val.data.split = 'full[:16]' + config.evals.i1k_val.data.split = 'validation[:16]' + config.evals.i1k_v2.data.split = 'test[:16]' + config.evals.i1k_a.data.split = 'test[:16]' + config.evals.i1k_r.data.split = 'test[:16]' + + return config \ No newline at end of file diff --git a/big_vision/configs/vit_s16_i1k.py b/big_vision/configs/vit_s16_i1k.py new file mode 100644 index 0000000000000000000000000000000000000000..d50dd26508713b67c434f0e677e58fbef7d8af13 --- /dev/null +++ b/big_vision/configs/vit_s16_i1k.py @@ -0,0 +1,105 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580. + +This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%, +see the tech report for more details. + +Command to run: + +big_vision.train \ + --config big_vision/configs/vit_s16_i1k.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` + +To run for 300ep, add `--config.total_epochs 300` to the command. +""" + +import ml_collections as mlc + + +def get_config(): + """Config for training.""" + config = mlc.ConfigDict() + + config.seed = 0 + config.total_epochs = 90 + config.num_classes = 1000 + config.loss = 'softmax_xent' + + config.input = {} + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 1024 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 + + pp_common = ( + '|value_range(-1, 1)' + '|onehot(1000, key="{lbl}", key_result="labels")' + '|keep("image", "labels")' + ) + config.input.pp = ( + 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' + + pp_common.format(lbl='label') + ) + pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common + + # To continue using the near-defunct randaug op. + config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug'] + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + + # Model section + config.model_name = 'vit' + config.model = dict( + variant='S/16', + rep_size=True, + pool_type='gap', + posemb='sincos2d', + ) + + # Optimizer section + config.grad_clip_norm = 1.0 + config.optax_name = 'scale_by_adam' + config.optax = dict(mu_dtype='bfloat16') + + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = dict(warmup_steps=10_000, decay_type='cosine') + + config.mixup = dict(p=0.2, fold_in=None) + + # Eval section + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + ) + config.evals = {} + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') + + return config diff --git a/big_vision/datasets/ai2d/ai2d.py b/big_vision/datasets/ai2d/ai2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/aokvqa/aokvqa.py b/big_vision/datasets/aokvqa/aokvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/chartqa/chartqa.py b/big_vision/datasets/chartqa/chartqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/coco35l/coco35l.py b/big_vision/datasets/coco35l/coco35l.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/core.py b/big_vision/datasets/core.py new file mode 100644 index 0000000000000000000000000000000000000000..07d2a2c6814646908fc5133cb5a54aec6d3b57b3 --- /dev/null +++ b/big_vision/datasets/core.py @@ -0,0 +1,77 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core data functions, dispatch calls to the requested dataset.""" +import importlib + + +# Note: intentionally not using ABC to avoid forcing implementation of every +# method, since one can imagine train-only datasets for example. +class DataSource: + """The API that any data source should implement.""" + + def get_tfdata(self, ordered, *, process_split=True, allow_cache=True): + """Creates this data object as a tf.data.Dataset. + + This will be called separately in each process, and it is up to the dataset + implementation to shard it accordingly if desired! + + Args: + ordered: if True, the dataset should use deterministic ordering, if False + it may have undefined ordering. Think of True == val, False == train. + process_split: if False then every process receives the entire dataset + (e.g. for evaluators running in a single process). + allow_cache: whether to allow caching the opened data or not. + + Returns: + A tf.data.Dataset object. + + Raises: + RuntimeError: if not implemented by the dataset, but called. + """ + raise RuntimeError("not implemented for {self.__class__.__name__}") + + @property + def total_examples(self): + """Returns number of examples in the dataset, regardless of sharding.""" + raise RuntimeError("not implemented for {self.__class__.__name__}") + + def num_examples_per_process(self): + """Returns a list of the numer of examples for each process. + + This is only needed for datasets that should go through make_for_inference. + + Returns: + Returns a list of the numer of examples for each process. + + Ideally, this would always be `[total() / nprocess] * nprocess`, but in + reality we can almost never perfectly shard a dataset across arbitrary + number of processes. + + One alternative option that can work in some cases is to not even shard + the dataset and thus return `[num_examples()] * nprocess. + + Raises: + RuntimeError: if not implemented by the dataset, but called. + """ + raise RuntimeError("not implemented for {self.__class__.__name__}") + + +def get(name, **kw): + if name.startswith("bv:"): + mod = importlib.import_module(f"big_vision.datasets.{name[3:]}") + return mod.DataSource(**kw) + else: + mod = importlib.import_module("big_vision.datasets.tfds") + return mod.DataSource(name, **kw) diff --git a/big_vision/datasets/countbenchqa/countbenchqa.py b/big_vision/datasets/countbenchqa/countbenchqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/docvqa/docvqa.py b/big_vision/datasets/docvqa/docvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/gqa/gqa.py b/big_vision/datasets/gqa/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/imagenet/class_names.py b/big_vision/datasets/imagenet/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/infovqa/infovqa.py b/big_vision/datasets/infovqa/infovqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/jsonl.py b/big_vision/datasets/jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..719deba2b25987a4b9d58b56474e420cb5b1e706 --- /dev/null +++ b/big_vision/datasets/jsonl.py @@ -0,0 +1,177 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple data input from .jsonl files.""" + +import hashlib +import json +from multiprocessing.pool import ThreadPool +import os +import tempfile +import urllib.request + +from absl import logging +import big_vision.datasets.core as ds_core +import jax +import numpy as np +import overrides +import tensorflow as tf + + +def cached_download(url, dest=None, verbose=True): + """Download `url` to local file and return path to that, but with caching.""" + # NOTE: there is a small chance of saving corrupted data if the process is + # interrupted in the middle of writing the file. Then, reading in the input + # pipeline will fail, and the fix is to nuke the temp folder. + + # Compute a temp name based on the URL, so we can check if we already + # downloaded it before. + dest = dest or os.path.join(tempfile.gettempdir(), "bv") + os.makedirs(dest, exist_ok=True) + dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest()) + + # NOTE: we should use last-modified header to know whether to re-download. + if os.path.isfile(dest): + return dest + + if verbose: + print(f"\rRetrieving {url} into {dest}", end="", flush=True) + + with urllib.request.urlopen(url) as f: + data = f.read() + with open(dest, "wb+") as f: + f.write(data) + return dest + + +class DataSource(ds_core.DataSource): + """.jsonl DataSource.""" + + def __init__(self, fname, *, fopen_keys=(), download_keys=(), + start=0, stop=float("inf")): + """Create data-source that's jsonl + data files (eg images). + + This correctly supports multi-host in that each host only reads a subset of + the dataset automatically. However, currently, all hosts download all items + if `download_keys` is specified. TODO: b/lbeyer - This can be improved. + + Args: + fname: str, the path to the jsonl file that holds the dataset. + fopen_keys: collection of str or dict, the keys in the dataset whose + string value actually is a file-path that should be opened and read, + and its content is what goes into the batch (eg image filenames + commonly ["image"]). + If a dict, the values are folders prefixed to the filenames. + Supports gs:// for reading from buckets. + download_keys: collection of str, the keys in the dataset whose string + value actually is a URL from which the file should be downloaded first. + files are downloaded to a persistent tmp folder using the URL hash as + filename. If the file already exists, the download is skipped. + Must be a subset of `fopen_keys`. + start: int, index of the first row to use; use for slicing the data. + stop: int or inf, index of the row after the last one to use. + + Note: + This simple data input does not allow for nested/hierarchical values, + or in any way more complicated values like vectors. Use TFDS for that. + + The way start/stop arguments are used is as in list slicing[start:stop]. + """ + self.examples = [] + + with tf.io.gfile.GFile(fname) as f: + for i, line in enumerate(f): + if (start or 0) <= i < (stop or float("inf")): + try: + self.examples.append(json.loads(line)) + except json.decoder.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e + + if download_keys: + for k in download_keys: + assert k in fopen_keys, ( + f"{k} in download_keys but missing from fopen_keys {fopen_keys}") + + # TODO: b/lbeyer - use info from trainer instead, move that to utils. + logging.info( # pylint: disable=logging-fstring-interpolation + f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} " + f"for dataset {fname} ({len(self.examples)} examples) ...") + + def _dl_one(ex): + for k in download_keys: + ex[k] = cached_download(ex[k]) + + ThreadPool(100).map(_dl_one, self.examples) + print("Done") + logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.") + + # Normalize. + if isinstance(fopen_keys, (list, tuple)): + self.fopen_keys = {k: "" for k in fopen_keys} + else: + self.fopen_keys = fopen_keys or {} + + # We need to apply fopen path prefix here already, because doing so while + # actually reading the files in TF, things are symbolic :( + for ex in self.examples: + for k, dirname in self.fopen_keys.items(): + ex[k] = os.path.join(dirname, ex[k]) + + def _indices(self, *, process_split=True, process_index=None): + indices = np.arange(len(self.examples)) + + if not process_split: + return list(indices) + + pid = jax.process_index() if process_index is None else process_index + return list(np.array_split(indices, jax.process_count())[pid]) + + @overrides.overrides + def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True): + del allow_cache # We don't cache anything anyways. + assert not process_split or len(self.examples) >= jax.process_count(), ( + "Process splitting the data with fewer examples than processes!?") + + my_idxs = self._indices(process_split=process_split) + if not ordered: + np.random.shuffle(my_idxs) + + dataset = tf.data.Dataset.from_generator( + generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs), + output_signature={ + "id": _guess_signature("0"), + **{k: _guess_signature(v) for k, v in self.examples[0].items()}, + }) + + def _read_files(example): + for k in self.fopen_keys: + example[k] = tf.io.read_file(example[k]) + return example + dataset = dataset.map(_read_files) + + return dataset + + @property + @overrides.overrides + def total_examples(self): + return len(self.examples) + + @overrides.overrides + def num_examples_per_process(self): + return [len(self._indices(process_index=pid)) + for pid in range(jax.process_count())] + + +def _guess_signature(value): + return tf.TensorSpec.from_tensor(tf.constant(value)) diff --git a/big_vision/datasets/nocaps/nocaps.py b/big_vision/datasets/nocaps/nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/okvqa/okvqa.py b/big_vision/datasets/okvqa/okvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/pope/pope.py b/big_vision/datasets/pope/pope.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/refcoco/refcoco.py b/big_vision/datasets/refcoco/refcoco.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/rsvqa_hr/rsvqa_hr.py b/big_vision/datasets/rsvqa_hr/rsvqa_hr.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/rsvqa_lr/rsvqa_lr.py b/big_vision/datasets/rsvqa_lr/rsvqa_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/scicap/scicap.py b/big_vision/datasets/scicap/scicap.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/science_qa/science_qa.py b/big_vision/datasets/science_qa/science_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/screen2words/screen2words.py b/big_vision/datasets/screen2words/screen2words.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/sequence_packing.py b/big_vision/datasets/sequence_packing.py new file mode 100644 index 0000000000000000000000000000000000000000..48966d3c488886b3ab0d0f061a1c88c57fdeabae --- /dev/null +++ b/big_vision/datasets/sequence_packing.py @@ -0,0 +1,77 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Packed Sequence Op.""" + +# Forked from +# https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py. + + +from typing import Dict, Optional, List, Union + +from flax import traverse_util +import tensorflow as tf + +AUTOTUNE = tf.data.experimental.AUTOTUNE +FLATTEN_SEPARATOR = "<|sep|>" + + +def pack_dataset( + dataset: tf.data.Dataset, + batch_size: int | None, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset: + """Creates a 'packed' version of a dataset on-the-fly. + + Wrap `tensorflow.grain` ops. + + This is meant to replace the irritation of having to create a separate + "packed" version of a dataset to train efficiently on TPU. + Each example in the output dataset represents several examples in the + input dataset. + + For each key in the input dataset, two additional keys are created: + _segment_ids: an int32 tensor identifying the parts + representing the original example. + _positions: an int32 tensor identifying the position within the original + example. + + Example: + Two input examples get combined to form an output example. + The input examples are: + {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} + {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} + The output example is: + { + "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] + "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] + "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] + "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] + "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] + "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] + } + 0 represents padding in both the inputs and the outputs. + Sequences in the incoming examples are truncated to length "length", and the + sequences in the output examples all have fixed (padded) length "length". + + Args: + dataset: A `tf.data.Dataset`. + batch_size: Batch size of the packed dataset. + key2length: An integer, or a dict from feature-key to integer. + keys: A list of strings (e.g. ["inputs", "targets"]). + + Returns: + A `tf.data.Dataset`. + """ + raise ValueError("Not implemented in OSS yet.") diff --git a/big_vision/datasets/stvqa/stvqa.py b/big_vision/datasets/stvqa/stvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/tallyqa/tallyqa.py b/big_vision/datasets/tallyqa/tallyqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/textcaps/textcaps.py b/big_vision/datasets/textcaps/textcaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/textvqa/textvqa.py b/big_vision/datasets/textvqa/textvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/tfds.py b/big_vision/datasets/tfds.py new file mode 100644 index 0000000000000000000000000000000000000000..0c15dbc26f46e87d4df27027c1cca5a01b5e74fa --- /dev/null +++ b/big_vision/datasets/tfds.py @@ -0,0 +1,94 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorFlow Datasets as data source for big_vision.""" +import functools + +import big_vision.datasets.core as ds_core +import jax +import numpy as np +import overrides +import tensorflow as tf +import tensorflow_datasets as tfds + + +class DataSource(ds_core.DataSource): + """Use TFDS as a data source.""" + + def __init__(self, name, split, data_dir=None, skip_decode=("image",)): + self.builder = _get_builder(name, data_dir) + self.split = split + # Each host is responsible for a fixed subset of data + process_splits = tfds.even_splits(split, jax.process_count()) + self.process_split = process_splits[jax.process_index()] + self.skip_decode = skip_decode + + @overrides.overrides + def get_tfdata( + self, ordered=False, *, process_split=True, allow_cache=True, **kw): + # The tf.data may use a lot of RAM, so we need to expose the option of not + # keeping this in memory when we use lots of input pipelines, such as when + # having many ephemeral evaluators. + return (_cached_get_dataset if allow_cache else _get_dataset)( + self.builder, self.skip_decode, + split=self.process_split if process_split else self.split, + shuffle_files=not ordered, + **kw) + + @property + @overrides.overrides + def total_examples(self): + return self.builder.info.splits[self.split].num_examples + + @overrides.overrides + def num_examples_per_process(self): + splits = tfds.even_splits(self.split, jax.process_count()) + return [self.builder.info.splits[s].num_examples for s in splits] + + +@functools.cache +def _get_builder(dataset, data_dir): + if dataset == "from_data_dir": + return tfds.builder_from_directory(data_dir) + else: + return tfds.builder(dataset, data_dir=data_dir, try_gcs=True) + + +# Cache as it may well take 1-2min on large datasets, and we may use the same +# multiple times (eg various evaluators). +def _get_dataset(builder, skip_decode, shuffle_files, split=None, **rckw): + """Returns a tf.data to be used.""" + ds = builder.as_dataset( + split=split, shuffle_files=shuffle_files, + read_config=tfds.ReadConfig( + skip_prefetch=True, # We prefetch after pipeline. + try_autocache=False, # We control this, esp. for few-shot. + add_tfds_id=True, + **rckw, + ), + decoders={ + f: tfds.decode.SkipDecoding() + for f in skip_decode if f in builder.info.features + }) + + def _hash_tfds_id(example): + id_ = tf.strings.to_hash_bucket_strong( + example["tfds_id"], + np.iinfo(np.uint32).max, # Max value + [3714561454027272724, 8800639020734831960]) # Magic. + example["_id"] = tf.bitcast(id_, tf.int32)[0] # good device dtype. + return example + + return ds.map(_hash_tfds_id) +_cached_get_dataset = functools.cache(_get_dataset) diff --git a/big_vision/datasets/vizwizvqa/vizwizvqa.py b/big_vision/datasets/vizwizvqa/vizwizvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/vqa/vqa.py b/big_vision/datasets/vqa/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/widgetcap/widgetcap.py b/big_vision/datasets/widgetcap/widgetcap.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/xgqa/xgqa.py b/big_vision/datasets/xgqa/xgqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/datasets/xm3600/xm3600.py b/big_vision/datasets/xm3600/xm3600.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/evaluators/__init__.py b/big_vision/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/evaluators/classification.py b/big_vision/evaluators/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..263ead8f5027f4b8e640b9ba42a72b3cbc33adf2 --- /dev/null +++ b/big_vision/evaluators/classification.py @@ -0,0 +1,76 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for the classfication task.""" +# pylint: disable=consider-using-from-import + +import functools + +from big_vision.evaluators import common +import big_vision.utils as u +import jax +import jax.numpy as jnp + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = 'jit' + + +# To avoid re-compiling the function for every new instance of the same +# evaluator on a different dataset! +@functools.cache +def get_eval_fn(predict_fn, loss_name): + """Produces eval function, also applies pmap.""" + @jax.jit + def _eval_fn(train_state, batch, labels, mask): + logits, *_ = predict_fn(train_state, batch) + + # Ignore the entries with all zero labels for evaluation. + mask *= labels.max(axis=1) + + loss = getattr(u, loss_name)( + logits=logits, labels=labels, reduction=False) + loss = jnp.sum(loss * mask) + + top1_idx = jnp.argmax(logits, axis=1) + # Extracts the label at the highest logit index for each image. + top1_correct = jnp.take_along_axis( + labels, top1_idx[:, None], axis=1)[:, 0] + ncorrect = jnp.sum(top1_correct * mask) + nseen = jnp.sum(mask) + return ncorrect, loss, nseen + return _eval_fn + + +class Evaluator: + """Classification evaluator.""" + + def __init__(self, predict_fn, loss_name, label_key='labels', **kw): + self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) + self.eval_fn = get_eval_fn(predict_fn, loss_name) + self.label_key = label_key + + def run(self, train_state): + """Computes all metrics.""" + ncorrect, loss, nseen = 0, 0, 0 + for _, batch in zip(range(self.steps), self.get_data_iter()): + labels, mask = batch.pop(self.label_key), batch.pop('_mask') + batch_ncorrect, batch_losses, batch_nseen = jax.device_get( + self.eval_fn(train_state, batch, labels, mask)) + ncorrect += batch_ncorrect + loss += batch_losses + nseen += batch_nseen + yield ('prec@1', ncorrect / nseen) + yield ('loss', loss / nseen) diff --git a/big_vision/evaluators/common.py b/big_vision/evaluators/common.py new file mode 100644 index 0000000000000000000000000000000000000000..42dcdbb4b52a5208673821b9c68df246709fcf6d --- /dev/null +++ b/big_vision/evaluators/common.py @@ -0,0 +1,228 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for evaluators in general.""" + +import dataclasses +import functools +import importlib +import json +import os +from typing import Any, Callable + +from absl import flags +from big_vision import input_pipeline +from big_vision.datasets import core as ds_core +from big_vision.pp import builder as pp_builder +import big_vision.utils as u +import flax +import jax +import numpy as np + +from tensorflow.io import gfile + + +def from_config(config, predict_fns, + write_note=lambda s: s, + get_steps=lambda key, cfg: cfg[f"{key}_steps"], + devices=None): + """Creates a list of evaluators based on `config`.""" + evaluators = [] + specs = config.get("evals", {}) + + for name, cfg in specs.items(): + write_note(name) + + # Pop all generic settings off so we're left with eval's kwargs in the end. + cfg = cfg.to_dict() + module = cfg.pop("type", name) + pred_key = cfg.pop("pred", "predict") + pred_kw = cfg.pop("pred_kw", None) + prefix = cfg.pop("prefix", f"{name}/") + cfg.pop("skip_first", None) + logsteps = get_steps("log", cfg) + for typ in ("steps", "epochs", "examples", "percent"): + cfg.pop(f"log_{typ}", None) + + # Use same batch_size as eval by default, to reduce fragmentation. + # TODO: eventually remove all the deprecated names... + cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size") # pylint: disable=line-too-long + + module = importlib.import_module(f"big_vision.evaluators.{module}") + + if devices is not None: + cfg["devices"] = devices + + api_type = getattr(module, "API", "pmap") + if api_type == "pmap" and "devices" in cfg: + raise RuntimeError( + "You are seemingly using the old pmap-based evaluator, but with " + "jit-based train loop, see (internal link) for more details.") + if api_type == "jit" and "devices" not in cfg: + raise RuntimeError( + "You are seemingly using new jit-based evaluator, but with " + "old pmap-based train loop, see (internal link) for more details.") + + try: + predict_fn = predict_fns[pred_key] + except KeyError as e: + raise ValueError( + f"Unknown predict_fn '{pred_key}'. Available predict_fns are:\n" + + "\n".join(predict_fns)) from e + if pred_kw is not None: + predict_fn = _CacheablePartial(predict_fn, flax.core.freeze(pred_kw)) + evaluator = module.Evaluator(predict_fn, **cfg) + evaluators.append((name, evaluator, logsteps, prefix)) + + return evaluators + + +@dataclasses.dataclass(frozen=True, eq=True) +class _CacheablePartial: + """partial(fn, **kwargs) that defines hash and eq - to help with jit caches. + + This is particularly common in evaluators when one has many evaluator + instances that run on difference slices of data. + + Example: + + ``` + f1 = _CacheablePartial(fn, a=1) + jax.jit(f1)(...) + jax.jit(_CacheablePartial(fn, a=1))(...) # fn won't be retraced. + del f1 + jax.jit(_CacheablePartial(fn, a=1))(...) # fn will be retraced. + ``` + """ + fn: Callable[..., Any] + kwargs: flax.core.FrozenDict + + def __call__(self, *args, **kwargs): + return functools.partial(self.fn, **self.kwargs)(*args, **kwargs) + + +def eval_input_pipeline( + data, pp_fn, batch_size, devices, keep_on_cpu=(), + cache="pipeline", prefetch=1, warmup=False, +): + """Create an input pipeline in the way used by most evaluators. + + Args: + data: The configuration to create the data source (like for training). + pp_fn: A string representing the preprocessing to be performed. + batch_size: The batch size to use. + devices: The devices that the batches are sharded and pre-fetched onto. + keep_on_cpu: See input_pipeline.start_global. Entries in the batch that + should be kept on the CPU, hence could be ragged or of string type. + cache: One of "none", "pipeline", "raw_data", "final_data". Determines what + part of the input stream should be cached across evaluator runs. They use + more and more RAM, but make evals faster, in that order. + - "none": Entirely re-create and destroy the input pipeline each run. + - "pipeline": Keep the (tf.data) pipeline object alive across runs. + - "raw_data": Cache the full raw data before pre-processing. + - "final_data": Cache the full raw data after pre-processing. + prefetch: How many batches to fetch ahead. + warmup: Start fetching the first batch at creation time (right now), + instead of once the iteration starts. + + Returns: + A tuple (get_iter, steps), the first element is a function that returns + the iterator to be used for an evaluation, the second one is how many steps + should be iterated for doing one evaluation. + """ + assert ( + cache is None + or cache.lower() in ("none", "pipeline", "raw_data", "final_data") + ), f"Unknown value for cache: {cache}" + data_source = ds_core.get(**data) + tfdata, steps = input_pipeline.make_for_inference( + data_source.get_tfdata(ordered=True, allow_cache=cache.lower() != "none"), + batch_size=batch_size, + num_ex_per_process=data_source.num_examples_per_process(), + preprocess_fn=pp_builder.get_preprocess_fn(pp_fn, str(data)), + cache_final=cache == "raw_data", + cache_raw=cache == "final_data") + get_data_iter = lambda: input_pipeline.start_global( + tfdata, devices, prefetch, keep_on_cpu, warmup) + + # Possibly create one persistent iterator: + if cache in ("pipeline", "raw_data", "final_data"): + data_iter = get_data_iter() + get_data_iter = lambda: data_iter + + return get_data_iter, steps + + +def process_sum(tree): + """Sums the pytree across all processes.""" + if jax.process_count() == 1: # Avoids corner-cases on donuts. + return tree + + with jax.transfer_guard_device_to_host("allow"): + gathered = jax.experimental.multihost_utils.process_allgather(tree) + return jax.tree.map(functools.partial(np.sum, axis=0), gathered) + + +def resolve_outfile(outfile, split="", **kw): + if not outfile: + return None + + # A caveat: when workdir doesn't exist but is in the `outfile`, we should + # skip. This is common in small runs or runlocal debuggings. + if "{workdir}" in outfile and not flags.FLAGS.workdir: + return None + + return outfile.format( + workdir=flags.FLAGS.workdir, + split="".join(c if c not in "[]%:" else "_" for c in split), + step=getattr(u.chrono, "prev_step", None), + **kw, + ) + + +def multiprocess_write_json(outfile, jobj): # jobj = "json object" + """Write a single json file combining all processes' `jobj`s.""" + if not outfile: + return + + outfile = resolve_outfile(outfile) + gfile.makedirs(os.path.dirname(outfile)) + + if isinstance(jobj, list): + combine_fn = list.extend + elif isinstance(jobj, dict): + combine_fn = dict.update + else: + raise TypeError(f"Can only write list or dict jsons, but got {type(jobj)}") + + # First, each process writes its own file. + with gfile.GFile(outfile + f".p{jax.process_index()}", "w+") as f: + f.write(json.dumps(jobj)) + + u.sync() # Wait for all files to be written; `with` above does close/flush. + + # Have process 0 collect, concat, and write final output. + all_json = type(jobj)() + if jax.process_index() == 0: + for pid in range(jax.process_count()): + with gfile.GFile(outfile + f".p{pid}", "r") as f: + combine_fn(all_json, json.loads(f.read())) + with gfile.GFile(outfile, "w+") as f: + f.write(json.dumps(all_json)) + + # Cleanup time + u.sync() + gfile.remove(outfile + f".p{jax.process_index()}") + + return all_json diff --git a/big_vision/evaluators/fewshot_lsr.py b/big_vision/evaluators/fewshot_lsr.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7019ad3fa58936975b631206947b3b33ecdc67 --- /dev/null +++ b/big_vision/evaluators/fewshot_lsr.py @@ -0,0 +1,245 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for few-shot evaluation.""" +# pylint: disable=consider-using-from-import,g-importing-member + +import functools + +import big_vision.datasets.core as ds_core +import big_vision.input_pipeline as input_pipeline +import big_vision.pp.builder as pp_builder +import big_vision.utils as u +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding as Sharding +from jax.sharding import PartitionSpec as P +import numpy as np + +BIAS_CONSTANT = 100.0 + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +# Setup function for few-shot regression on CPU to avoid "polluting" the TPU. +@u.jit_cpu(static_argnums=(2,)) +def _precompute_cache(x, y, num_classes): + """Cache quantities to speed-up the computation of L2-regularized least-sq.""" + # Whiten + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + 1e-5 + x = (x - mean) / std + + # Add a constant feature for the bias, large so it's almost unregularized: + x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT) + + # To one-hot representation rescaled into {-1, 1} + y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0 + + num_points, dim = x.shape + # Let N be the number of points, D the dimension and C the number of classes. + # We have x of shape (N, D) and y of shape (N, C). + # For least-squares, we can compute + # + # (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y + # (B) when D > N, x^T (x x^T + l2 Id)^{-1} y + # + # We pre-compute the eigen-decomposition of either x^T x or x x^T which + # becomes q diag(eigs) q^T with q unitary matrix either (D, D) or (N, N) + # and eigs a vector (D,) or (N,). + # + # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1} + # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T. + # (SVD would be more natural here, but it proved slower, so we use eigh) + # + # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs, + # where lhs/rhs are pre-computed left/right-hand sides to specify. + # + # Detailed evaluation in terms of time and fewshot metrics can be found in + # (internal link) + # + # Implemented by Rodolphe Jenatton. + if num_points >= dim: + eigs, q = jnp.linalg.eigh(x.T @ x) + rhs = q.T @ (x.T @ y) + lhs = q + else: + eigs, q = jnp.linalg.eigh(x @ x.T) + rhs = q.T @ y + lhs = x.T @ q + + cache = { + "eigs": eigs, + "rhs": rhs, + "lhs": lhs, + "mean": mean, + "std": std + } + return cache + + +@u.jit_cpu() +def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg): + """Computes (x,y) linear regression accuracy on (x_test, y_test).""" + + x_test = (x_test - cache["mean"]) / cache["std"] + x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT) + + rhs = cache["rhs"] + lhs = cache["lhs"] + eigs = cache["eigs"] + + # See comments in _precompute_cache for context about the formula. + scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs)) + scaling = scaling.reshape((1, -1)) + w = (lhs * scaling) @ rhs + # Predict test-set values and measure their accuracy + preds = jnp.argmax(x_test @ w, axis=1) + return jnp.mean(preds == y_test) + + +class Evaluator: + """Class for few-shot evaluation.""" + + def __init__(self, predict_fn, batch_size, + datasets, shots, l2_reg, + pp_train, pp_eval, display_first, + representation_layer=None, num_seeds=3, + label_key="label", mask_key="_mask", data_dir=None, *, + devices): + self.datasets = datasets + self.shots = shots + self.l2_reg = l2_reg + self.batch_size = batch_size + self.pp_tr = pp_train + self.pp_te = pp_eval + self.display_first = display_first + self._datasets = {} # Cache for tfds data. Persists while object is alive. + self._repr = {} # Cache for precomputed repr. Persists within the run call. + self.num_seeds = num_seeds + self.label_key = label_key + self.mask_key = mask_key + self.data_dir = data_dir + self.devices = devices + self.mesh = jax.sharding.Mesh(devices, ("devices",)) + self.repr_fn = self.get_representation_fn( + predict_fn, representation_layer) + + def get_representation_fn(self, predict_fn, representation_layer): + # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs. + @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P())) + def _repr_fn(train_state, batch, labels, mask): + zimg, *_, out = predict_fn(train_state, batch) + if representation_layer is not None: + rep = u.tree_get(out, representation_layer) + else: + rep = zimg + return rep, labels, mask + return _repr_fn + + # Setup input pipeline. + def _get_dataset(self, dataset, train_split, test_split): + """Lazy-loads given dataset.""" + key = (dataset, train_split, test_split) + try: + return self._datasets[key] + except KeyError: + # NOTE: only supporting TFDS data for now for bwd compat/lazyness. + train_data = ds_core.get( + name=dataset, split=train_split, data_dir=self.data_dir + ) + test_data = ds_core.get( + name=dataset, split=test_split, data_dir=self.data_dir + ) + train_ds, batches_tr = input_pipeline.make_for_inference( + train_data.get_tfdata(ordered=True), + num_ex_per_process=train_data.num_examples_per_process(), + batch_size=self.batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr)) + test_ds, batches_te = input_pipeline.make_for_inference( + test_data.get_tfdata(ordered=True), + num_ex_per_process=test_data.num_examples_per_process(), + batch_size=self.batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te)) + + num_classes = train_data.builder.info.features[self.label_key].num_classes + return self._datasets.setdefault( + key, (train_ds, batches_tr, test_ds, batches_te, num_classes)) + + def _get_repr(self, params, data, steps): + """Compute representation for the whole dataset.""" + pre_logits_list = [] + labels_list = [] + for batch, _ in zip( + input_pipeline.start_global(data, self.devices, 0), range(steps)): + labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key) + pre_logits, labels, mask = jax.device_get(self.repr_fn( + params, batch, labels, mask)) + mask = mask.astype(bool) + pre_logits_list.append(pre_logits[mask]) + labels_list.append(labels[mask]) + pre_logits = np.concatenate(pre_logits_list, axis=0) + labels = np.concatenate(labels_list, axis=0) + + return pre_logits, labels + + def compute_fewshot_metrics(self, train_state, seed, + dataset, train_split, test_split): + """Compute few-shot metrics on one dataset.""" + if dataset in self._repr: + repr_train, labels_train, repr_test, labels_test, num_classes = ( + self._repr[dataset]) + else: + train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset( + dataset, train_split, test_split) + repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr) + repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te) + self._repr[dataset] = (repr_train, labels_train, + repr_test, labels_test, + num_classes) + + # Collect where we have samples of which classes. + rng = np.random.default_rng(seed) + class_indices = [rng.permutation(np.where(labels_train == cls_i)[0]) + for cls_i in range(num_classes)] + + results = {} + for shots in self.shots: + all_idx = [indices[:shots] for indices in class_indices] + all_idx = np.concatenate(all_idx, axis=0) + x = u.put_cpu(repr_train[all_idx]) + y = u.put_cpu(labels_train[all_idx]) + repr_test, labels_test = u.put_cpu((repr_test, labels_test)) + + # Note the code is optimized to solve multiple LSR tasks for changing l2 + # strength, even though we currently used the fixed l2_reg constant. + cache = _precompute_cache(x, y, num_classes) + acc = _eig_fewshot_acc_fn( + cache, repr_test, labels_test, u.put_cpu(self.l2_reg)) + results[shots] = jax.device_get(acc) + + return results + + def run(self, train_state): + """New API executed in terms of old API.""" + self._repr = {} + for seed in range(self.num_seeds): + for name, dataset_args in self.datasets.items(): + result = self.compute_fewshot_metrics(train_state, seed, *dataset_args) + for shots, v in result.items(): + prefix = "a/" if (name, shots) in self.display_first else "z/" + suffix = f"-seed-{seed}" + yield f"{prefix}{name}_{shots}shot{suffix}", v diff --git a/big_vision/evaluators/mean.py b/big_vision/evaluators/mean.py new file mode 100644 index 0000000000000000000000000000000000000000..a38fb21d3cd7ab7d37a5734c67640994c0956b36 --- /dev/null +++ b/big_vision/evaluators/mean.py @@ -0,0 +1,80 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for computing mean of per-example metrics. + +This evaluator can be used in two ways: + 1. Create a new evaluator with reduced boilerplate by inheriting from it. + 2. For quick prototyping, use this with predict_fns which return the metrics. +""" +from functools import partial +from typing import Mapping + +from big_vision.evaluators import common + +import jax +import jax.numpy as jnp +import numpy as np + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = 'jit' + + +# Note: global to avoid jax re-compiling across different evaluator instances. +@partial(jax.jit, static_argnums=0) +def _run_predict_fn(predict_fn, train_state, batch): + """Sum per-example metrics weighted by `_mask`.""" + metrics = predict_fn(train_state, batch) + mask = batch['_mask'] + # Sanity check output format of predict_fn. + assert isinstance(metrics, Mapping), 'predict_fn must return a dict' + for y in jax.tree.leaves(metrics): + if y.shape != mask.shape: + raise ValueError( + f'Expected per-example metrics of shape {mask.shape} found ' + f'{jax.tree.map(lambda x: x.shape, metrics)}.') + metrics = {**metrics, '_mask': mask} + return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics) + + +class Evaluator: + """Report the mean of per-example metrics computed by predict_fn. + + `predict_fn(params, batch)` must return a dict from metric name to + per-example metrics of shape [batch_size]. + """ + + def __init__(self, predict_fn, **kw): + self.get_data_iter, self.steps = common.eval_input_pipeline(**kw) + self.predict_fn = partial(_run_predict_fn, predict_fn) + + def run(self, train_state): + """Computes all metrics.""" + metrics = [] + + # Compute batch metrics without blocking. + for _, batch in zip(range(self.steps), self.get_data_iter()): + batch_metrics = self.predict_fn(train_state, batch) + metrics.append(batch_metrics) + + # Transfer metrics (blocking). + metrics = jax.device_get(metrics) + + # Accumulate metrics across batches. + metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics) + mask_sum = metrics_sum.pop('_mask') + for key, value_sum in metrics_sum.items(): + yield (key, value_sum / mask_sum) diff --git a/big_vision/evaluators/save.py b/big_vision/evaluators/save.py new file mode 100644 index 0000000000000000000000000000000000000000..49bcfc59b9fd9c613611b1edcbd157b2d8c2d6d5 --- /dev/null +++ b/big_vision/evaluators/save.py @@ -0,0 +1,121 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator that save inputs and outputs of prediction functions.""" +import functools + +from absl import flags +from absl import logging + +from big_vision import input_pipeline +from big_vision import optax as bv_optax +from big_vision import utils +from big_vision.datasets import core as ds_core +from big_vision.pp import builder as pp_builder + +import jax +import numpy as np + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = 'jit' + + +# Note: global to avoid jax re-compiling across different evaluator instances. +def _run_predict_fn(predict_fn, train_state, batch): + """Run predict_fn and gather all outputs on all devices.""" + y = predict_fn(train_state, batch) + return {'inputs': batch, 'outputs': y} + + +class Evaluator: + """Evaluator that saves the inputs and outputs of a prediction function. + + Example configuration: + + ``` + config.evals.save_pred = { + 'type': 'save', + 'pred': 'inference', + 'outfile': '{workdir}/inference-{step:09d}.npz', + 'data': ..., 'pp_fn': ..., 'log_steps': ..., + } + ``` + + Results can then be easily inspected in a notebook such as: + + ``` + results = utils.load_checkpoint("") + inputs, outputs = (results["inputs"], results["outputs"]) + ``` + """ + + def __init__(self, predict_fn, data, pp_fn, batch_size, outfile, + cache_final=True, cache_raw=False, prefetch=1, *, devices): + replicate = jax.sharding.NamedSharding( + jax.sharding.Mesh(devices, ('devices',)), + jax.sharding.PartitionSpec() + ) + self.predict_fn = functools.partial( + jax.jit(_run_predict_fn, static_argnums=0, out_shardings=replicate), + predict_fn, + ) + + data = ds_core.get(**data) + self.dataset, self.steps = input_pipeline.make_for_inference( + data.get_tfdata(ordered=True), + batch_size=batch_size, + num_ex_per_process=data.num_examples_per_process(), + preprocess_fn=pp_builder.get_preprocess_fn(pp_fn), + cache_final=cache_final, + cache_raw=cache_raw, + ) + self.data_iter = input_pipeline.start_global( + self.dataset, devices, prefetch + ) + + self.outfile = outfile + + def run(self, train_state): + """Compute all predictions, gather in main host and save in outfile.""" + step = jax.device_get(bv_optax.get_count(train_state['opt'], jittable=True)) + outfile = self.outfile.format(workdir=flags.FLAGS.workdir, step=step) + + count = 0 + outputs = [] + for _, batch in zip(range(self.steps), self.data_iter): + out = self.predict_fn(train_state, batch) + if jax.process_index(): + continue + + out = jax.device_get(out) + mask = out['inputs']['_mask'] + out = jax.tree.map(lambda x: x[mask == 1], out) # pylint: disable=cell-var-from-loop + count += mask.shape[0] + out['inputs'].pop('_mask') + outputs.append(out) + + logging.log_every_n_seconds( + logging.INFO, 'Processed %i examples so far.', 60, + count) + + if jax.process_index(): + return + + logging.info('Saving %d examples in %s', count, outfile) + outputs = jax.tree.map(lambda *x: np.concatenate(x, axis=0), *outputs) + utils.save_checkpoint(outputs, outfile, compressed=True) + return + + yield None # pylint: disable=unreachable diff --git a/big_vision/input_pipeline.py b/big_vision/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..afe20894c3ac5fee4e0ef6bd549151b79f4f224c --- /dev/null +++ b/big_vision/input_pipeline.py @@ -0,0 +1,357 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ImageNet input pipeline.""" +import collections +import functools +import itertools +import math +import multiprocessing.pool + +from absl import logging +from big_vision.datasets import sequence_packing +import big_vision.datasets.core as ds_core +import big_vision.pp.builder as pp_builder +import big_vision.utils as u +import einops +import jax +import numpy as np +import tensorflow as tf + + +DEFAULT_NUM_PARALLEL_CALLS = 100 + + +def make_for_train( + data, preprocess_fn, batch_size, + shuffle_buffer_size=None, cache_raw=False, + num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2, + *, + pre_filter_fn=None, post_filter_fn=None, + pack=None, skip_errors=False, +): + """Makes an input pipeline for training.""" + # Use data filtering at your own risk: the actual split sizes won't be known + # in advance, so epoch-based things won't work correctly. + + data = _add_tpu_host_options(data) + + data = data.filter(pre_filter_fn) if pre_filter_fn else data + data = data.cache() if cache_raw else data + + # First shuffle and then repeat (each with a different shuffle). This way + # the data for one epoch is all seen before the next one is processed and + # significantly affects the number of times each example is seen when + # processing for small number of epochs. + if shuffle_buffer_size: + data = data.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True) + data = data.repeat(None) + + data = data.map(preprocess_fn, num_parallel_calls=num_parallel_calls) + data = data.filter(post_filter_fn) if post_filter_fn else data + + data = data.ignore_errors(log_warning=True) if skip_errors else data + + if pack: + data = sequence_packing.pack_dataset( + data, + batch_size // jax.process_count() if batch_size else None, + pack.to_dict()) + + # Drop remainder makes shape fully static, so we can later use it if needed. + if batch_size: + data = data.batch(batch_size // jax.process_count(), drop_remainder=True) + if prefetch: # None means autotune, but we never want that. + data = data.prefetch(prefetch) + return data + + +def training(input_config): + """Reads the data from a single dataset, or mixes it from multiple. + + The data is read either from one or mixed from multiple datasets, depending + on the `input_config`. + + Args: + input_config: Configures the input pipeline. See input_pipeline_test for + examples. + + Returns: + A tuple containing (possibly mixed) tf.data.Dataset and a total number of + training examples. + """ + per_pipeline_configs = ( + "shuffle_buffer_size", "cache_raw", "num_parallel_calls", + "pre_filter_fn", "post_filter_fn", "pack", "skip_errors") + def config_to_kw(config): + assert "filter_fn" not in config, "Deprecated; use `pre_filter_fn` instead." + return {k: config[k] for k in per_pipeline_configs if k in config} + + batch_size = input_config.batch_size + # Handle separately the common case when no mixing happens. + if isinstance(input_config.data.get("name"), str): + train_data = ds_core.get(**input_config.data) + train_ds = make_for_train( + data=train_data.get_tfdata(ordered=False, + **input_config.get("tfdata", {})), + batch_size=batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")), + prefetch=input_config.get("prefetch", 2), # Default 2 for bwd compat. + **config_to_kw(input_config) + ) + return train_ds, train_data.total_examples + + # A helpful error instead of silent ignore: + for k in per_pipeline_configs: + assert k not in input_config, f"{k} is per-dataset in multi-input." + + # Parallelize the loading of datasets when doing data mixture. + # For larger mixes, we sometimes spend >5min when doing sequentially. + # NOTE: functools.cache is thread-safe. + def _make(name_and_weight): + name, weight = name_and_weight + dataset = input_config[name] + train_data = ds_core.get(**dataset.data) + dataset = make_for_train( + data=train_data.get_tfdata(ordered=False, **dataset.get("tfdata", {})), + # Don't batch the data just yet, it will be done after + # mixing the different datasets below. + batch_size=None, + preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name), + prefetch=0, # Prefetching each pipeline leads to huge OOMs. + **config_to_kw(dataset) + ) + if keys := input_config.get("keep_only"): + dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys}) + return name, dataset, weight, train_data.total_examples + + names, datasets, weights, totals = [], [], [], [] + pool = multiprocessing.pool.ThreadPool( + input_config.get("thread_pool_size", len(input_config.data)) + ) + for name, dataset, weight, total in pool.map( + # Skip weight=0 datasets as a convenient optimization in sweeps. + _make, ((name, w) for name, w in input_config.data.items() if w)): + names.append(name) + datasets.append(dataset) + weights.append(weight) + totals.append(total) + + # Normalize the weights such that they sum up to 1. + weights = [x / sum(weights) for x in weights] + + logging.info( + "NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals), + "\n".join(f"{ds}: {n} ({w * 100:.2g}%)" + for ds, n, w in zip(names, totals, weights)) + ) + + train_ds = tf.data.Dataset.sample_from_datasets( + datasets, weights, stop_on_empty_dataset=True) + if input_config.get("pack"): + train_ds = sequence_packing.pack_dataset( + train_ds, + input_config["batch_size"] // jax.process_count(), + input_config.pack.to_dict()) + + train_ds = train_ds.batch( + input_config["batch_size"] // jax.process_count(), drop_remainder=True) + if (pf := input_config.get("prefetch", 2)): + train_ds = train_ds.prefetch(pf) + + return train_ds, sum(totals) + + +# The pipeline below is used for evals in multi-{G,T}PU and multi-host settings. +# As the total number of examples may not be evenly divisible accross all +# devices, we use the `infinite tf.data padding` trick, which was suggested by +# Andreas Steiner and also implemented by him in the clu library: +# https://github.com/google/CommonLoopUtils/blob/84b777c42dfd3fb6685537138433bfeb5241a006/clu/deterministic_data.py#L304. +def make_for_inference( + data, preprocess_fn, batch_size, num_ex_per_process, + cache_raw=False, cache_final=False, + num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1, +): + """Makes an input pipeline for inference.""" + + data = _add_tpu_host_options(data) + data = data.cache() if cache_raw else data + data = data.map(_add_internal_fields(preprocess_fn), + num_parallel_calls=num_parallel_calls) + data = data.concatenate(_get_pad_data(data)) + + local_batch_size = batch_size // jax.process_count() + # This is just like `batch`, but allows batching elements of different shapes + # into a tf.RaggedTensor. Elements of the same fixed shape remain tf.Tensors. + # Since we do 'infinite' padding it is safe to drop the remainder. + data = data.ragged_batch(batch_size=local_batch_size, drop_remainder=True) + + # We need to make sure that all hosts process all data and exactly the same + # number of batches. Below we take max per-host num examples and use it on all + # hosts to derive the number of batches. + num_batches = math.ceil(max(num_ex_per_process) / local_batch_size) + data = data.take(num_batches) + + # Note we cache data after a finite number of batches is taken. + data = data.cache() if cache_final else data + data = data.repeat() + data = data.prefetch(prefetch) if prefetch else data + return data, num_batches + + +def _get_pad_data(data): + def zeros_like_spec(spec): + # For unknown/flexible dimensions (None), just use 0 instead. + return tf.zeros([x or 0 for x in spec.shape], spec.dtype) + + zero = jax.tree.map(zeros_like_spec, data.element_spec) + return tf.data.Dataset.from_tensors(zero).repeat() + + +def _add_internal_fields(pp_fn): + """Wraps pp_fn to add _mask and _id keys.""" + # Adds internal keys, that we either, in this order of preference: + # 1. keep from result of pp_fn, + # 2. carry over from raw (not pp_fn'd) example, or + # 3. add, if that makes sense. + def _pp_fn(example): + result = pp_fn(example) + # _mask will be False on padded examples (see _get_pad_data). + result.setdefault("_mask", example.get("_mask", tf.constant(True))) + # Not all data-sources can provide an ID. Only carry-over if it can: + if "_id" in example and "_id" not in result: + result["_id"] = example["_id"] + return result + return _pp_fn + + +def _add_tpu_host_options(data): + options = tf.data.Options() + options.threading.private_threadpool_size = 48 + options.threading.max_intra_op_parallelism = 1 + + # Stop a whole bunch of magic stuff that eats up all RAM: + options.experimental_optimization.inject_prefetch = False + + return data.with_options(options) + + +def prefetch_iterator(it, n): + """Runs iterator `it` ahead for `n` steps. Adapted from flax.""" + if not n: + yield from it + return + queue = collections.deque() + + def enqueue(n_steps): # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(it, n_steps): + # Prefetching will parallelize any processing that happens in a different + # thread (like `jax.device_put()`), but it will be of no use for + # processing that happens in the same thread. + queue.append(data) + + enqueue(n) # Fill up the buffer. + while queue: + yield queue.popleft() + enqueue(1) + + +def threadstart_iterator(it): + """Starts an iterator right away in a background thread.""" + # We already want to "start" the iterator in order to start the underlying + # dataset prefetch mechanisms, so here we get the first element. But we don't + # want to lose it from training, so we yield that one afterwards. + # (internal link) + pool = multiprocessing.pool.ThreadPool(processes=1) + first_ex_promise = pool.apply_async(lambda: next(it)) + + yield first_ex_promise.get() + yield from it + + +def tf_to_numpy(x): + """Convert any TF types to numpy.""" + if isinstance(x, tf.Tensor): + if x.dtype != tf.string: # Dense, non-string tensor? Easy! + return x.numpy() + else: # A dense string tensor? Turn into actual strings, not bytes. + return np.vectorize(bytes.decode, otypes=[str])(x.numpy()) + + # The rest deals with RaggedTensors, for two main reasons: + # - For strings, recursively apply the above conversion + # - For common cases (eg batch of images), return more reasonable shapes. + + # Replace all None's in the shape by a fixed number, in the (somewhat common) + # case that they are marked ragged, but really all have the same shape. + real_shape = list(x.shape) + for i, s in enumerate(real_shape[1:]): + if s is not None: continue + rowlens = np.diff(x.nested_row_splits[i]) + if len(set(rowlens)) == 1: + real_shape[i + 1] = rowlens[0] + + if None not in real_shape: + return tf_to_numpy(x.flat_values).reshape(real_shape) + + # It's actually ragged, reconstruct the array from the variable length pieces. + splits = x.row_splits.numpy() + rows = [tf_to_numpy(x.values[splits[i]:splits[i + 1]]) + for i in range(len(splits) - 1)] + return np.fromiter(rows, dtype=object) + + +# Note that the order of global devices for sharding data is important and +# should be compatible with device order used for models params, state, etc. +def start_global( + data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False): + """Starts the global input pipeline.""" + def maybe_shard(name, x): + if name in keep_on_cpu: + return tf_to_numpy(x) + return u.make_fsarray_from_local_slice(x, global_devices) + + it = iter(data) + if warmup: # actually pre-fill shuffle buffers etc. + it = threadstart_iterator(it) + + it = (u.tree_map_with_names(maybe_shard, elem) for elem in it) + return prefetch_iterator(it, n_prefetch) + + +########################################################################## +# The code below is pmap-specific and is deprecated, please switch to jit. +########################################################################## + + +def shard_and_put(x, shard=True, put=True): + x = np.asarray(memoryview(x)) # No-copy conversion: http://(internal link) + if shard: + x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count()) + if shard and put: # Only works for pmap (for now). + x = jax.device_put_sharded(list(x), jax.local_devices()) + return x + + +def start_input_pipeline(data, n_prefetch=1, shard=True): + fn = functools.partial(shard_and_put, shard=shard, put=n_prefetch) + it = (jax.tree.map(fn, elem) for elem in iter(data)) + return prefetch_iterator(it, n_prefetch) + + +def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None): + def maybe_shard_and_put(name, x): + return x if name in (ragged or {}) else shard_and_put(x, shard) + + it = (u.tree_map_with_names(maybe_shard_and_put, elem) for elem in iter(data)) + return prefetch_iterator(it, n_prefetch) diff --git a/big_vision/models/__init__.py b/big_vision/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/__pycache__/__init__.cpython-310.pyc b/big_vision/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/__pycache__/bit.cpython-310.pyc b/big_vision/models/__pycache__/bit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/__pycache__/common.cpython-310.pyc b/big_vision/models/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/__pycache__/vit.cpython-310.pyc b/big_vision/models/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/bit.py b/big_vision/models/bit.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4235df9ec87549590396deb69739e310e6770d --- /dev/null +++ b/big_vision/models/bit.py @@ -0,0 +1,162 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResNet V1 with GroupNorm.""" + +from typing import Optional, Sequence, Union + +from big_vision import utils +from big_vision.models import common +import flax +import flax.linen as nn +import flax.training.checkpoints +import jax.numpy as jnp +import numpy as np + + +def weight_standardize(w, axis, eps): + w = w - jnp.mean(w, axis=axis) + w = w / (jnp.std(w, axis=axis) + eps) + return w + + +class StdConv(nn.Conv): + + def param(self, name, *a, **kw): + param = super().param(name, *a, **kw) + if name == "kernel": + param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) + return param + + +class ResidualUnit(nn.Module): + """Bottleneck ResNet block.""" + nmid: Optional[int] = None + strides: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + nmid = self.nmid or x.shape[-1] // 4 + nout = nmid * 4 + + residual = x + if x.shape[-1] != nout or self.strides != (1, 1): + residual = StdConv(nout, (1, 1), self.strides, use_bias=False, + name="conv_proj")(residual) + residual = nn.GroupNorm(name="gn_proj")(residual) + + y = StdConv(nmid, (1, 1), use_bias=False, name="conv1")(x) + y = nn.GroupNorm(name="gn1")(y) + y = nn.relu(y) + y = StdConv(nmid, (3, 3), self.strides, use_bias=False, name="conv2")(y) + y = nn.GroupNorm(name="gn2")(y) + y = nn.relu(y) + y = StdConv(nout, (1, 1), use_bias=False, name="conv3")(y) + + y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y) + y = nn.relu(residual + y) + return y + + +class ResNetStage(nn.Module): + """One stage of ResNet.""" + block_size: int + first_stride: Sequence[int] = (1, 1) + nmid: Optional[int] = None + + @nn.compact + def __call__(self, x): + x = ResidualUnit(self.nmid, strides=self.first_stride, name="unit1")(x) + for i in range(1, self.block_size): + x = ResidualUnit(self.nmid, name=f"unit{i + 1}")(x) + return x + + +class Model(nn.Module): + """ResNetV1.""" + num_classes: Optional[int] = None + width: float = 1 + depth: Union[int, Sequence[int]] = 50 + + @nn.compact + def __call__(self, image, *, train=False): + del train # Unused + blocks = get_block_desc(self.depth) + width = int(64 * self.width) + + out = {} + + # Root block + x = StdConv(width, (7, 7), (2, 2), use_bias=False, name="conv_root")(image) + x = nn.GroupNorm(name="gn_root")(x) + x = nn.relu(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") + out["stem"] = x + + # Stages + x = ResNetStage(blocks[0], nmid=width, name="block1")(x) + out["stage1"] = x + for i, block_size in enumerate(blocks[1:], 1): + x = ResNetStage(block_size, nmid=width * 2 ** i, + first_stride=(2, 2), name=f"block{i + 1}")(x) + out[f"stage{i + 1}"] = x + out["pre_logits_2d"] = x + + # Head + x = out["pre_logits"] = jnp.mean(x, axis=(1, 2)) + + if self.num_classes: + head = nn.Dense(self.num_classes, name="head", + kernel_init=nn.initializers.zeros) + out["logits_2d"] = head(out["pre_logits_2d"]) + x = out["logits"] = head(out["pre_logits"]) + + return x, out + + +# A dictionary mapping the number of layers in a resnet to the number of +# blocks in each stage of the model. +# NOTE: Does not include 18/34 as they also need non-bottleneck block! +def get_block_desc(depth): + if isinstance(depth, list): # Be robust to silly mistakes. + depth = tuple(depth) + return { + 26: [2, 2, 2, 2], # From timm, gets ~75% on ImageNet. + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + 200: [3, 24, 36, 3] + }.get(depth, depth) + + +def fix_old_checkpoints(params): + """Modifies params from old checkpoints to run with current implementation.""" + params = flax.core.unfreeze( + flax.training.checkpoints.convert_pre_linen(params)) + # Old linen used to store non-squeezed GN params. + params = flax.traverse_util.unflatten_dict({ + k: np.squeeze(v) if (set(k) + & {"gn_root", "gn_proj", "gn1", "gn2", "gn3"}) else v + for k, v in flax.traverse_util.flatten_dict(params).items() + }) + return params + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Load init from checkpoint.""" + del model_cfg # Unused + params = utils.load_params(init_file) + params = common.merge_params(params, init_params, dont_load) + params = fix_old_checkpoints(params) + return params diff --git a/big_vision/models/bit_paper.py b/big_vision/models/bit_paper.py new file mode 100644 index 0000000000000000000000000000000000000000..26e5ba83616ce046a78d1a9b3fa32f8b4cbc1000 --- /dev/null +++ b/big_vision/models/bit_paper.py @@ -0,0 +1,260 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BiT models as in the paper (ResNet V2) w/ loading of public weights. + +See reproduction proof: http://(internal link)/qY70qs6j944 +""" + +import functools +import re +from typing import Optional, Sequence, Union + +from big_vision import utils as u +from big_vision.models import bit +from big_vision.models import common +import flax.linen as nn +import jax.numpy as jnp + + +def standardize(x, axis, eps): + x = x - jnp.mean(x, axis=axis, keepdims=True) + x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps) + return x + + +# Defined our own, because we compute normalizing variance slightly differently, +# which does affect performance when loading pre-trained weights! +class GroupNorm(nn.Module): + """Group normalization (arxiv.org/abs/1803.08494).""" + ngroups: int = 32 + + @nn.compact + def __call__(self, x): + + input_shape = x.shape + group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups) + + x = x.reshape(group_shape) + + # Standardize along spatial and group dimensions + x = standardize(x, axis=[1, 2, 4], eps=1e-5) + x = x.reshape(input_shape) + + bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]]) + x = x * self.param('scale', nn.initializers.ones, bias_scale_shape) + x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape) + return x + + +class StdConv(nn.Conv): + + def param(self, name, *a, **kw): + param = super().param(name, *a, **kw) + if name == 'kernel': + param = standardize(param, axis=[0, 1, 2], eps=1e-10) + return param + + +class RootBlock(nn.Module): + """Root block of ResNet.""" + width: int + + @nn.compact + def __call__(self, x): + x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], + use_bias=False, name='conv_root')(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)]) + return x + + +class ResidualUnit(nn.Module): + """Bottleneck ResNet block.""" + nmid: Optional[int] = None + strides: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + nmid = self.nmid or x.shape[-1] // 4 + nout = nmid * 4 + conv = functools.partial(StdConv, use_bias=False) + + residual = x + x = GroupNorm(name='gn1')(x) + x = nn.relu(x) + + if x.shape[-1] != nout or self.strides != (1, 1): + residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x) + + x = conv(nmid, (1, 1), name='conv1')(x) + x = GroupNorm(name='gn2')(x) + x = nn.relu(x) + x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)], + name='conv2')(x) + x = GroupNorm(name='gn3')(x) + x = nn.relu(x) + x = conv(nout, (1, 1), name='conv3')(x) + + return x + residual + + +class ResNetStage(nn.Module): + """A stage (sequence of same-resolution blocks).""" + block_size: int + nmid: Optional[int] = None + first_stride: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + out = {} + x = out['unit01'] = ResidualUnit( + self.nmid, strides=self.first_stride, name='unit01')(x) + for i in range(1, self.block_size): + x = out[f'unit{i+1:02d}'] = ResidualUnit( + self.nmid, name=f'unit{i+1:02d}')(x) + return x, out + + +class Model(nn.Module): + """ResNetV2.""" + num_classes: Optional[int] = None + width: int = 1 + depth: Union[int, Sequence[int]] = 50 # 50/101/152, or list of block depths. + head_zeroinit: bool = True + + @nn.compact + def __call__(self, image, *, train=False): + blocks = bit.get_block_desc(self.depth) + width = int(64 * self.width) + out = {} + + x = out['stem'] = RootBlock(width=width, name='root_block')(image) + + # Blocks + x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x) + for i, block_size in enumerate(blocks[1:], 1): + x, out[f'stage{i + 1}'] = ResNetStage( + block_size, width * 2 ** i, + first_stride=(2, 2), name=f'block{i + 1}')(x) + + # Pre-head + x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x) + x = out['pre_logits_2d'] = nn.relu(x) + x = out['pre_logits'] = jnp.mean(x, axis=(1, 2)) + + # Head + if self.num_classes: + kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, name='head', **kw) + out['logits_2d'] = head(out['pre_logits_2d']) + x = out['logits'] = head(out['pre_logits']) + + return x, out + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Loads the TF-dumped NumPy or big_vision checkpoint. + + Args: + init_params: random init params from which the new head is taken. + init_file: comes from `config.model_init`, can either be an absolute + path (ie starts with /) to the checkpoint, or a string like + "L-imagenet2012" describing one of the variants from the paper. + model_cfg: the model configuration. + dont_load: list of param names to be reset to init. + + Returns: + The loaded parameters. + """ + + # Support for vanity model names from the paper. + vanity = { + 'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz', + 'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz', + } + if init_file[0] in ('L', 'M', 'S'): # The models from the original paper. + # Supported names are of the following type: + # - 'M' or 'S': the original "upstream" model without fine-tuning. + # - 'M-ILSVRC2012': i21k model fine-tuned on i1k. + # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101. + # each VTAB fine-tuning was run 3x, so there's run0, run1, run2. + if '-' in init_file: + up, down = init_file[0], init_file[1:] + else: + up, down = init_file, '' + down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down) # normalize + fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz' + fname = f'gs://bit_models/{fname}' + else: + fname = vanity.get(init_file, init_file) + + params = u.load_params(fname) + params = maybe_convert_big_transfer_format(params) + return common.merge_params(params, init_params, dont_load) + + +def maybe_convert_big_transfer_format(params_tf): + """If the checkpoint comes from legacy codebase, convert it.""" + + # Only do anything at all if we recognize the format. + if 'resnet' not in params_tf: + return params_tf + + # For ease of processing and backwards compatibility, flatten again: + params_tf = dict(u.tree_flatten_with_names(params_tf)[0]) + + # Works around some files containing weird naming of variables: + for k in list(params_tf): + k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k) + if k2 != k: + params_tf[k2] = params_tf[k] + del params_tf[k] + + params = { + 'root_block': {'conv_root': {'kernel': params_tf[ + 'resnet/root_block/standardized_conv2d/kernel']}}, + 'norm-pre-head': { + 'bias': params_tf['resnet/group_norm/beta'][None, None, None], + 'scale': params_tf['resnet/group_norm/gamma'][None, None, None], + }, + 'head': { + 'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0], + 'bias': params_tf['resnet/head/conv2d/bias'], + } + } + + for block in ('block1', 'block2', 'block3', 'block4'): + params[block] = {} + units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys() + if p.find(block) >= 0]) + for unit in units: + params[block][unit] = {} + for i, group in enumerate('abc', 1): + params[block][unit][f'conv{i}'] = { + 'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel'] # pylint: disable=line-too-long + } + params[block][unit][f'gn{i}'] = { + 'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None], # pylint: disable=line-too-long + 'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None], # pylint: disable=line-too-long + } + + projs = [p for p in params_tf.keys() + if p.find(f'{block}/{unit}/a/proj') >= 0] + assert len(projs) <= 1 + if projs: + params[block][unit]['conv_proj'] = { + 'kernel': params_tf[projs[0]] + } + + return params diff --git a/big_vision/models/common.py b/big_vision/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..175dfa77a1360bc2a0276fa12245c8d357b39406 --- /dev/null +++ b/big_vision/models/common.py @@ -0,0 +1,133 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities shared across models.""" + +from absl import logging +import big_vision.utils as u +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def merge_params(loaded, inited, dont_load=(), match_dtype=False): + """Makes `loaded` pytree match `init`, warning or failing on mismatch. + + Args: + loaded: pytree of parameters, typically loaded from a checkpoint. + inited: pytree of parameter, typically coming from model init. + dont_load: List of regexes for parameters which shall not be taken + from `loaded`, either because they should remain at their init value, + or because they are missing on either side. + match_dtype: returned pytree as leaves converted to dtype from `inited`. + + Returns: + If successful, a new pytree which matches the structure of `init` + but contains values from `loaded`, except for `dont_load`. + + If structures don't match and mismatches are not covered by regexes in + `dont_load` argument, then raises an exception with more information. + """ + if inited is None: # A useful shortcut for example for colabs. + return loaded + + dont_load = u.check_and_compile_patterns(dont_load) + + def should_merge(name): + return not any(pattern.fullmatch(name) for pattern in dont_load) + + loaded_flat, _ = u.tree_flatten_with_names(loaded) + inited_flat, _ = u.tree_flatten_with_names(inited) + loaded_flat = {k: v for k, v in loaded_flat} + inited_flat = {k: v for k, v in inited_flat} + + # Let's first build the pytree from all common keys. + merged = {} + for name, init_val in inited_flat.items(): + # param is present in both. Load or ignore it! + if name in loaded_flat and should_merge(name): + merged[name] = loaded_flat[name] + if match_dtype: + merged[name] = loaded_flat[name].astype(init_val.dtype) + else: + logging.info("Ignoring checkpoint and using init value for %s", name) + merged[name] = init_val + + def pp(title, names, indent=" "): # Just pretty-printing + if names: + return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names)) + else: + return "" + + # Now, if there are keys that only exist in inited or loaded, be helpful: + not_in_loaded = inited_flat.keys() - loaded_flat.keys() + not_in_inited = loaded_flat.keys() - inited_flat.keys() + logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded)) + logging.info(pp("Parameters in checkpoint but not in model", not_in_inited)) + + # And now see if any of them are not explicitly ignored => an error + not_in_loaded = {k for k in not_in_loaded if should_merge(k)} + not_in_inited = {k for k in not_in_inited if should_merge(k)} + + if not_in_loaded or not_in_inited: + raise ValueError( + pp("Params in checkpoint", loaded_flat.keys()) + "\n" + + pp("Params in model (code)", inited_flat.keys()) + "\n" + + pp("Params in model (code) but not in checkpoint and not `dont_load`ed", + not_in_loaded, indent=" - ") + "\n" + # Special indent for tests. + pp("Params in checkpoint but not in model (code) and not `dont_load`ed", + not_in_inited, indent=" + ")) # Special indent for tests. + + return u.recover_tree(merged.keys(), merged.values()) + + +class AddPositionEmbs(nn.Module): + """Adds positional embeddings to the inputs, supports caching for decode. + + Attributes: + decode: whether to run in single-position autoregressive mode. + """ + decode: bool = False + + @nn.compact + def __call__(self, inputs, posemb): + """Applies AddPositionEmbs module. + + Adds posemb to the inputs, supports single-position autoregressive mode. + + Args: + inputs: input data [batch_size, seq_len, emb_dim]. + posemb: positional embeddings. + + Returns: + output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim]. + """ + assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}" + _, seq_len, emb_dim = inputs.shape + pe = posemb[:, :seq_len, :] + + if self.decode: + is_initialized = self.has_variable("cache", "cache_index") + # We use a cache position index for tracking decoding position. + cache_index = self.variable("cache", "cache_index", + lambda: jnp.array(0, dtype=jnp.uint32)) + if is_initialized: + i = cache_index.value + cache_index.value = i + 1 + # Returns posemb[0, i, :], the positional embedding for the + # current decoding position. + pe = jax.lax.dynamic_slice(posemb, + start_indices=jnp.array((0, i, 0)), + slice_sizes=(1, 1, emb_dim)) + return inputs + pe diff --git a/big_vision/models/mlp_mixer.py b/big_vision/models/mlp_mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..58bd4b99d21f061693da007b26dd24013e341851 --- /dev/null +++ b/big_vision/models/mlp_mixer.py @@ -0,0 +1,177 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MLP-Mixer model.""" + +from typing import Optional, Tuple +from absl import logging + +from big_vision import utils +from big_vision.models import common + +import einops +import flax.linen as nn +import flax.training.checkpoints +import jax +import jax.numpy as jnp + + +class MlpBlock(nn.Module): + mlp_dim: int + + @nn.compact + def __call__(self, x): + y = nn.Dense(self.mlp_dim)(x) + y = nn.gelu(y) + return nn.Dense(x.shape[-1])(y) + + +class MixerBlock(nn.Module): + """Mixer block layer.""" + tokens_mlp_dim: int + channels_mlp_dim: int + drop_p: float + + @nn.compact + def __call__(self, x, *, train=False): + y = nn.LayerNorm()(x) + y = jnp.swapaxes(y, 1, 2) + y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y) + y = jnp.swapaxes(y, 1, 2) + x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng) + y = nn.LayerNorm()(x) + y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y) + return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng) + + +class MlpMixer(nn.Module): + """Mixer architecture.""" + patch_size: Tuple[int, int] + num_classes: Optional[int] + num_blocks: int + hidden_dim: int + tokens_mlp_dim: int + channels_mlp_dim: int + model_name: Optional[str] = None + stoch_depth: float = 0.0 + + @nn.compact + def __call__(self, image, *, train=False): + out = {} + x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size, + strides=self.patch_size, name="stem")(image) + x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c") + for i in range(self.num_blocks): + drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth + x = out[f"block_{i}"] = MixerBlock( + self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train) + x = nn.LayerNorm(name="pre_head_layer_norm")(x) + x = out["pre_logits"] = jnp.mean(x, axis=1) + if self.num_classes: + x = out["logits"] = nn.Dense( + self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x) + return x, out + + +def Model(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name + """Factory function to easily create a Model variant like "L/16".""" + + if variant is not None: + model_size, patch = variant.split("/") + kw.setdefault("patch_size", (int(patch), int(patch))) + config = { + "S": { + "hidden_dim": 512, + "num_blocks": 8, + "channels_mlp_dim": 2048, + "tokens_mlp_dim": 256 + }, + "B": { + "hidden_dim": 768, + "num_blocks": 12, + "channels_mlp_dim": 3072, + "tokens_mlp_dim": 384 + }, + "L": { + "hidden_dim": 1024, + "num_blocks": 24, + "channels_mlp_dim": 4096, + "tokens_mlp_dim": 512 + }, + "H": { + "hidden_dim": 1280, + "num_blocks": 32, + "channels_mlp_dim": 5120, + "tokens_mlp_dim": 640 + }, + }[model_size] + + for k, v in config.items(): + kw.setdefault(k, v) + + logging.info("Mixer config: %s", kw) + return MlpMixer(num_classes=num_classes, **kw) + + +def load(init_params, init_file, model_cfg, dont_load=()): + """Load checkpoint.""" + + del model_cfg + # Shortcut names for some canonical paper checkpoints: + init_file = { + # pylint: disable=line-too-long + # Pretrained models from the MLP-Mixer paper: https://arxiv.org/abs/2105.01601. + "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz", + "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz", + "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz", + "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz", + # pylint: enable=line-too-long + }.get(init_file, init_file) + restored_params = utils.load_params(init_file) + restored_params = flax.training.checkpoints.convert_pre_linen(restored_params) + + if "Mixer" in restored_params: + restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop( + "encoder_norm" + ) + restored_params["stem"] = restored_params.pop("embedding") + def unflatten_dense(d): + return { + "Dense_0": { + "bias": d["bias1"].squeeze(), + "kernel": d["kernel1"].squeeze(), + }, + "Dense_1": { + "bias": d["bias2"].squeeze(), + "kernel": d["kernel2"].squeeze(), + }, + } + for k, v in restored_params["Mixer"].items(): + assert k.startswith("encoderblock_"), k + v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0")) + v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0")) + restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v + del restored_params["Mixer"] + + # possibly use the random init for some of the params (such as, the head). + restored_params = common.merge_params(restored_params, init_params, dont_load) + + return restored_params + + +def _stoch_depth_mask(x, drop_p, deterministic, make_rng): + if not deterministic and drop_p: + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape) + return 1.0 diff --git a/big_vision/models/ppp/__init__.py b/big_vision/models/ppp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/ppp/gemma.py b/big_vision/models/ppp/gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/models/vit.py b/big_vision/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7f536cf3b17ee97ec0cac3cc1f282dc0a8a699c3 --- /dev/null +++ b/big_vision/models/vit.py @@ -0,0 +1,480 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A refactored and simplified ViT. + +However, the names of modules are made to match the old ones for easy loading. +""" + +from typing import Optional, Sequence, Union + +from absl import logging +from big_vision import utils +from big_vision.models import common +import flax +import flax.linen as nn +import flax.training.checkpoints +import jax +import jax.numpy as jnp +import numpy as np +import scipy.ndimage + + +def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] + + assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = jnp.einsum("m,d->md", y.flatten(), omega) + x = jnp.einsum("m,d->md", x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + + +def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): + if typ == "learn": + return self.param(name, nn.initializers.normal(stddev=1/np.sqrt(width)), + (1, np.prod(seqshape), width), dtype) + elif typ == "sincos2d": + return posemb_sincos_2d(*seqshape, width, dtype=dtype) + else: + raise ValueError(f"Unknown posemb type: {typ}") + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + """Applies Transformer MlpBlock module.""" + inits = dict( + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + ) + + d = x.shape[-1] + x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x) + # In some extreme batch-size cases, this is needed as of Sept 2024: + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout)(x, deterministic) + x = nn.Dense(d, dtype=self.dtype_mm, **inits)(x) + return x + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + out = {} + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + y = nn.LayerNorm()(x) + y = out["sa"] = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=deterministic, + dtype=self.dtype_mm, + )(y, y) + y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb")) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+sa"] = x + y + + y = nn.LayerNorm()(x) + y = out["mlp"] = MlpBlock( + mlp_dim=self.mlp_dim, dropout=self.dropout, + dtype_mm=self.dtype_mm, + )(y, deterministic) + y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb")) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+mlp"] = x + y + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + return x, out + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + dropout: float = 0.0 + scan: bool = False + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): + out = {} + + if self.scan: + block = nn.remat( + Encoder1DBlock, + prevent_cse=False, + static_argnums=(2,), # 0=self, 2=deterministic + policy=getattr(jax.checkpoint_policies, self.remat_policy, None), + ) + x, scan_out = nn.scan( + block, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.depth)( + name="encoderblock", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout)(x, deterministic) + for lyr in range(self.depth): + out[f"block{lyr:02d}"] = jax.tree.map(lambda o, l=lyr: o[l], scan_out) + else: + # Input Encoder + for lyr in range(self.depth): + block_cur = Encoder1DBlock( + name=f"encoderblock_{lyr}", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, num_heads=self.num_heads, + dropout=self.dropout) + x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) + out["pre_ln"] = x # Alias for last block, but without the number in it. + + return nn.LayerNorm(name="encoder_norm")(x), out + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + + @nn.compact + def __call__(self, x): + # TODO + n, l, d = x.shape # pylint: disable=unused-variable + probe = self.param("probe", nn.initializers.xavier_uniform(), + (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + # TODO: dropout on head? + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] + + +class _Model(nn.Module): + """ViT model.""" + + num_classes: Optional[int] = None + patch_size: Sequence[int] = (16, 16) + width: int = 768 + depth: int = 12 + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + posemb: str = "learn" # Can also be "sincos2d" + rep_size: Union[int, bool] = False + dropout: float = 0.0 + pool_type: str = "gap" # Can also be "map" or "tok" + head_zeroinit: bool = True + scan: bool = False + # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, image, *, train=False): + out = {} + + image = jnp.asarray(image, self.dtype_mm) + + # Patch extraction + x = out["stem"] = nn.Conv( + self.width, self.patch_size, strides=self.patch_size, + padding="VALID", name="embedding", dtype=self.dtype_mm)(image) + + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # Add posemb before adding extra token. + x = out["with_posemb"] = x + get_posemb( + self, self.posemb, (h, w), c, "pos_embedding", x.dtype) + + if self.pool_type == "tok": + cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) + x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) + + n, l, c = x.shape # pylint: disable=unused-variable + x = nn.Dropout(rate=self.dropout)(x, not train) + + x, out["encoder"] = Encoder( + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout, + scan=self.scan, + remat_policy=self.remat_policy, + dtype_mm=self.dtype_mm, + name="Transformer")( + x, deterministic=not train) + encoded = out["encoded"] = x + + if self.pool_type == "map": + x = out["head_input"] = MAPHead( + num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + elif self.pool_type == "gap": + x = out["head_input"] = jnp.mean(x, axis=1) + elif self.pool_type == "0": + x = out["head_input"] = x[:, 0] + elif self.pool_type == "tok": + x = out["head_input"] = x[:, 0] + encoded = encoded[:, 1:] + elif self.pool_type == "none": + pass + else: + raise ValueError(f"Unknown pool type: '{self.pool_type}'") + + x_2d = jnp.reshape(encoded, [n, h, w, -1]) + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + hid = nn.Dense(rep_size, name="pre_logits") + # NOTE: In the past we did not include tanh in pre_logits. + # For few-shot, it should not matter much, as it whitens anyways. + x_2d = nn.tanh(hid(x_2d)) + x = nn.tanh(hid(x)) + + out["pre_logits_2d"] = x_2d + out["pre_logits"] = x + + if self.num_classes: + kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, name="head", **kw) + x_2d = out["logits_2d"] = head(x_2d) + x = out["logits"] = head(x) + + return x, out + + +def Model(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name + """Factory function, because linen really don't like what I'm doing!""" + return _Model(num_classes, **{**decode_variant(variant), **kw}) + + +def decode_variant(variant): + """Converts a string like "B" or "B/32" into a params dict.""" + if variant is None: + return {} + + v, patch = variant, {} + if "/" in variant: + v, patch = variant.split("/") + patch = {"patch_size": (int(patch), int(patch))} + + return { + # pylint:disable=line-too-long + # Reference: Table 2 of https://arxiv.org/abs/2106.04560. + "width": {"mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792}[v], + "depth": {"mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56}[v], + "mlp_dim": {"mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360}[v], + "num_heads": {"mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16}[v], + # pylint:enable=line-too-long + **patch + } + + +def resample_posemb(old, new): + """This function implements "high-res finetuning" for transformer models.""" + # Rescale the grid of position embeddings. Param shape is (1,N,1024) + if old.shape == new.shape: + return old + + logging.info("ViT: resize %s to %s", old.shape, new.shape) + gs_old = int(np.sqrt(old.shape[1])) + gs_new = int(np.sqrt(new.shape[1])) + logging.info("ViT: grid-size from %s to %s", gs_old, gs_new) + grid = old.reshape(gs_old, gs_old, -1) + + zoom = (gs_new/gs_old, gs_new/gs_old, 1) + grid = scipy.ndimage.zoom(grid, zoom, order=1) + grid = grid.reshape(1, gs_new*gs_new, -1) + return grid + + +def fix_old_checkpoints(params): + """Fix small bwd incompat that can't be resolved with names in model def.""" + + params = flax.core.unfreeze( + flax.training.checkpoints.convert_pre_linen(params)) + + # Original ViT paper variant had posemb in a module: + if "posembed_input" in params["Transformer"]: + logging.info("ViT: Loading and fixing VERY old posemb") + posemb = params["Transformer"].pop("posembed_input") + params["pos_embedding"] = posemb["pos_embedding"] + + # Widely used version before 2022 had posemb in Encoder: + if "pos_embedding" in params["Transformer"]: + logging.info("ViT: Loading and fixing old posemb") + params["pos_embedding"] = params["Transformer"].pop("pos_embedding") + + # Old vit.py used to first concat [cls] token, then add posemb. + # This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy + # so we changed to add posemb then concat [cls]. We can recover the old + # checkpoint by manually summing [cls] token and its posemb entry. + if "pos_embedding" in params: + pe = params["pos_embedding"] + if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]): + logging.info("ViT: Loading and fixing combined cls+posemb") + pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:] + if "cls" in params: + params["cls"] += pe_cls + + # MAP-head variants during ViT-G development had it inlined: + if "probe" in params: + params["MAPHead_0"] = { + k: params.pop(k) for k in + ["probe", "MlpBlock_0", "MultiHeadDotProductAttention_0", "LayerNorm_0"] + } + + return params + + +def pyloop_to_scan(params_pyloop): + """Converts a python for-loop ViT checkpoint to a lax.scan based one.""" + # On a high level, they are the same except that the for loop has separate + # array pytrees for each encoderblock, while the scan one has just one + # encoderblock pytree, with all block's params concatenated. + + params_scan = jax.tree.map(lambda x: x, params_pyloop) # Structural copy + t = params_scan["Transformer"] + + # Find highest index of encoderblocks in the checkpoint (they start at 0): + encoderblocks = {k for k in t if k.startswith("encoderblock_")} + depth = 1 + max({int(k.split("_")[-1]) for k in encoderblocks}) + + def stack(*values): + return np.stack(values) + + # Stack all encoderblocks into a single one: + t["encoderblock"] = jax.tree.map( + stack, *[t[f"encoderblock_{lyr}"] for lyr in range(depth)]) + + for lyr in range(depth): + del t[f"encoderblock_{lyr}"] + + return params_scan + + +def scan_to_pyloop(params_scan): + """Converts a lax.scan ViT checkpoint to a python for-loop based one.""" + # See comment in pyloop_to_scan. + + params_scan = jax.tree.map(lambda x: x, params_scan) # Structural copy + t = params_scan["Transformer"] + + # Find out how many encoderblocks there are + depth = len(t["encoderblock"]["LayerNorm_0"]["bias"]) + + # Create that many encoderblocks, each with their slice of their sub-pytree. + for lyr in range(depth): + block = jax.tree.map(lambda x, lyr=lyr: x[lyr], t["encoderblock"]) + t[f"encoderblock_{lyr}"] = block + + del t["encoderblock"] + return params_scan + + +def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above. + """Load init from checkpoint, both old model and this one. +Hi-res posemb.""" + init_file = VANITY_NAMES.get(init_file, init_file) + restored_params = utils.load_params(init_file) + + restored_params = fix_old_checkpoints(restored_params) + + # Detect attempts to load non-scan checkpoint into scan model. + if (model_cfg.get("scan") and + "encoderblock" not in restored_params["Transformer"]): + restored_params = pyloop_to_scan(restored_params) + if (not model_cfg.get("scan") + and "encoderblock" in restored_params["Transformer"]): + restored_params = scan_to_pyloop(restored_params) + + # possibly use the random init for some of the params (such as, the head). + restored_params = common.merge_params(restored_params, init_params, dont_load) + + # resample posemb if needed. + # TODO: Take this from model_cfg to avoid need for init_params. + if init_params and "pos_embedding" in init_params: + restored_params["pos_embedding"] = resample_posemb( + old=restored_params["pos_embedding"], + new=init_params["pos_embedding"]) + + return restored_params + + +# Shortcut names for some canonical paper checkpoints: +VANITY_NAMES = { + # pylint: disable=line-too-long + # Recommended models from https://arxiv.org/abs/2106.10270 + # Many more models at https://github.com/google-research/vision_transformer + "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz", + + # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580 + "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz", + "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz", + "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz", + + # DeiT-3 checkpoints from https://github.com/facebookresearch/deit/blob/main/README_revenge.md + # First layer converted to take inputs in [-1,1] + "deit3_S_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_1k.npz", + "deit3_S_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_21k.npz", + "deit3_S_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_1k.npz", + "deit3_S_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_21k.npz", + "deit3_B_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_1k.npz", + "deit3_B_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_21k.npz", + "deit3_B_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_1k.npz", + "deit3_B_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_21k.npz", + "deit3_L_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_1k.npz", + "deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz", + "deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz", + "deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz", + + # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343 + "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img", + "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img", + "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img", + "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img", + "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img", + "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img", + "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img", + "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img", + "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img", + # pylint: enable=line-too-long +} diff --git a/big_vision/optax.py b/big_vision/optax.py new file mode 100644 index 0000000000000000000000000000000000000000..39ddb0fc3d8075985a75d7cdf150d430e141d681 --- /dev/null +++ b/big_vision/optax.py @@ -0,0 +1,225 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gradient transformations and other optax utilities.""" + +import operator +import big_vision.utils as u +import jax +import jax.numpy as jnp +import optax + + +def find_states(opt_state, cls): + leaves = jax.tree.leaves( + opt_state, is_leaf=lambda node: isinstance(node, cls)) + return [leaf for leaf in leaves if isinstance(leaf, cls)] + + +def get_count(opt_state, jittable=False): + """Returns `ScaleByScheduleState.count` from `opt_state` as an integer.""" + counts = [ + state.count + for state in find_states(opt_state, optax.ScaleByScheduleState) + ] + if jittable: + return counts[0] + else: + counts = {int(c) for c in counts} + assert len(counts) == 1, f"Expected exactly 1 ScaleByScheduleState:{counts}" + return next(iter(counts)) + + +def replace_frozen(schedule, pytree, replacement, log=None): + """Replaces values matching frozen params in `pytree` with `replacement`.""" + if not isinstance(schedule, (list, tuple)): + return pytree + masks, scheds = _make_mask_trees(pytree, schedule, log=log) + frozen_mask, _, _ = _split_frozen(masks, scheds) + return jax.tree.map( + lambda v, f: replacement if f else v, pytree, frozen_mask) + + +def clip_by_per_example_global_norm( + max_norm: float, +) -> optax.GradientTransformation: + """Clips the norm of per-example gradients.""" + + def init_fn(params): + del params + return optax.EmptyState() + + def update_fn(updates, state, params=None): + del params + grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates) + batch_size = grads_flat[0].shape[0] + clipped, _ = optax.per_example_global_norm_clip(grads_flat, max_norm) + grads_sum = jax.tree_util.tree_unflatten(grads_treedef, clipped) + grads_mean = jax.tree_util.tree_map(lambda x: x / batch_size, grads_sum) + return grads_mean, state + + return optax.GradientTransformation(init_fn, update_fn) + + +def make(config, params, *, sched_kw): + """Returns gradient transform and learning rate functions.""" + + # Global schedule. No schedule means frozen. + schedule = config.get("schedule", {}) + if not isinstance(schedule, (tuple, list)): + schedule = [(".*", schedule)] + masks, scheds = _make_mask_trees(params, schedule, "config.schedule") + frozen_mask, masks, scheds = _split_frozen(masks, scheds) + not_frozen_mask = jax.tree.map(operator.not_, frozen_mask) + def create_schedule(mult=1.0, **kw): + assert "base" not in kw, kw + return u.create_learning_rate_schedule(base=mult, **kw) + schedule_fns = [create_schedule(**sched_kw, **sched) for sched in scheds] + schedule_txs = [ + optax.masked(optax.scale_by_schedule(schedule_fn), mask) + for schedule_fn, mask in zip(schedule_fns, masks) + ] + [ + # Removes weight decay updates. Note that weight decay already has an + # independent mask (which cannot be combined easily with a second mask), + # so instead we multiply updates for frozen params with zero. + optax.masked(optax.set_to_zero(), frozen_mask) + ] + + # Gradient clipping. + if clip_norm := config.get("grad_clip_norm"): + if config.get("grad_clip_per_example"): + clip_tx = clip_by_per_example_global_norm(clip_norm) + else: + clip_tx = optax.clip_by_global_norm(clip_norm) + grad_clip_norm_tx = optax.masked(clip_tx, not_frozen_mask) + else: + grad_clip_norm_tx = optax.identity() + + # Optimizer updates. + tx_func = operator.attrgetter(config.optax_name)(optax) + opt_txs = [optax.masked(tx_func(**config.get("optax", {})), not_frozen_mask)] + assert "optim" not in config, "Deprecated option, use config.optax." + + # Learning rate multipliers. Defaults to 1.0. + lr_mult_txs = [optax.scale(config.lr)] + if config.get("lr_mults"): + masks, mults = _make_mask_trees(params, config.lr_mults, "config.lr_mults") + assert all(mult > 0 for mult in mults), ( + f"Use schedule=None for parameter freezing instead of lr_mults={mults}") + lr_mult_txs += [ + optax.masked(optax.scale(mult), mask) + for mult, mask in zip(mults, masks) + ] + + # Weight decay. Defaults to 0.0. + # Weight decay is not gradient-based but instead uses "params side-input". + # Hence, weight decay is additive and independent of previous gradient-based + # updates. + assert "weight_decay" not in config, "Deprecated option. Use wd and schedule." + assert config.get("weight_decay_decouple", True), ( + "Coupled weight decay not supported anymore.") + if config.get("wd"): + wd_mults = config.get("wd_mults", [(".*/kernel$", 1.0)]) + masks, mults = _make_mask_trees(params, wd_mults, "config.wd_mults") + weight_decay_txs = [ + optax.add_decayed_weights(config.wd * mult, mask) + for mult, mask in zip(mults, masks) + ] + else: + weight_decay_txs = [] + + # Combine gradient updates and learning rate schedules. + return optax.chain( + grad_clip_norm_tx, + *opt_txs, + *lr_mult_txs, + *weight_decay_txs, + *schedule_txs, + optax.scale(-1.0)), schedule_fns + + +def _make_mask_trees(params, patterns_values, log): + patterns, values = zip(*patterns_values) + masks = u.make_mask_trees(params, patterns, log=log) + return masks, values + + +def _split_frozen(masks, scheds): + """Computes `frozen_mask` and updates `masks` and `scheds`.""" + # Specifying `None` as a scheduler freezes params. + all_false = jax.tree.map(lambda *bools: not any(bools), *masks) + not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v] + assert not not_covered, ( + f"All params must be covered (use `None` for freezing): {not_covered}") + frozen_masks = [ + mask for mask, sched in zip(masks, scheds) if sched is None] + frozen_mask = jax.tree.map( + lambda *bools: any(bools), *frozen_masks, + all_false) # `all_false` is required when `frozen_masks==[]`. + masks, scheds = zip(*( + (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None)) + return frozen_mask, masks, scheds + + +############ Custom BigVision optimizers ####################################### +# Currently there's only one custom optimizer and we don't foresee new ones in +# the near future, we opt not to create a new optimizer folder/module for just +# one isolated case. If there will be more optimizers, we can consider moving +# them into individual files in a subfolder. + + +# A dummy object to allow for foo.bar access syntax, see +# https://stackoverflow.com/a/19476841/2366315 +optax.big_vision = type("", (), {})() + + +def scale_by_adafactor(min_dim_size_to_factor=32, + decay_rate=0.8, decay_offset=0, + beta2_cap=0.999, + clipping_threshold=None, + momentum=0.9, dtype_momentum=jnp.bfloat16, + eps=1e-30): + """The BigVision variant of Adafactor optimizer.""" + + def _decay_rate_pow(i, exponent): + """Second-order moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return jnp.minimum(beta2_cap, 1.0 - t**(-exponent)) + + scale_by_rms = optax.scale_by_factored_rms( + factored=True, + decay_rate=decay_rate, + step_offset=decay_offset, + min_dim_size_to_factor=min_dim_size_to_factor, + epsilon=eps, + decay_rate_fn=_decay_rate_pow) + + clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold + else optax.identity()) + + mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum) + if momentum else optax.identity()) + + return optax.chain(scale_by_rms, clip, mom) + +optax.big_vision.scale_by_adafactor = scale_by_adafactor # pytype: disable=module-attr + + +# A few more aliases we use frequently: +def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False): + """SGD-Momentum with half-precision accumulator.""" + return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov) + +optax.big_vision.momentum_hp = momentum_hp # pytype: disable=module-attr +optax.big_vision.sgd = optax.identity # pytype: disable=module-attr diff --git a/big_vision/optax_test.py b/big_vision/optax_test.py new file mode 100644 index 0000000000000000000000000000000000000000..86f7bd9999079b393565ad5a718b4c1dbd815e79 --- /dev/null +++ b/big_vision/optax_test.py @@ -0,0 +1,341 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for optax.""" + +from absl.testing import absltest +from absl.testing import parameterized +from big_vision import optax as bv_optax +import chex +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax + + +class OptaxTest(parameterized.TestCase): + + def test_get_count(self): + params = jax.tree.map(jnp.array, {"a": 1.}) + tx = optax.masked( + optax.scale_by_schedule(lambda step: step), + {"a": True}, + ) + opt_state = tx.init(params) + self.assertEqual(bv_optax.get_count(opt_state), 0) + _, opt_state = tx.update(params, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), 1) + + def test_split_frozen(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + sched1 = dict(decay_type="cosine") + sched2 = dict(decay_type="linear") + schedule = [ + (".*/kernel", sched1), + (".*/bias", sched2), + ] + masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds) + chex.assert_trees_all_equal( + frozen_mask, + {"Dense_0": {"kernel": False, "bias": False}}, + ) # pyformat: disable + chex.assert_trees_all_equal( + masks, + ( + {"Dense_0": {"kernel": True, "bias": False}}, + {"Dense_0": {"kernel": False, "bias": True}}, + ), + ) # pyformat: disable + self.assertEqual(scheds, (sched1, sched2)) + # freeze some + schedule = [ + (".*/bias", None), + ("Dense_0/.*", sched1), + (".*", None), + ] + masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds) + chex.assert_trees_all_equal( + frozen_mask, + {"Dense_0": {"kernel": False, "bias": True}}, + ) # pyformat: disable + chex.assert_trees_all_equal( + masks, + ({"Dense_0": {"kernel": True, "bias": False}},), + ) # pyformat: disable + self.assertEqual(scheds, (sched1,)) + # does not cover all params - fails + schedule = [ + (".*/kernel", None), + ] + masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule") + with self.assertRaisesRegex(AssertionError, "All params must be covered"): + _ = bv_optax._split_frozen(masks, scheds) + + def test_replace_frozen(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + schedule = [ + (".*/kernel", {}), + (".*", None), + ] + chex.assert_trees_all_equal( + bv_optax.replace_frozen(schedule, params, 0.), + {"Dense_0": {"kernel": 1., "bias": 0.}}, + ) # pyformat: disable + + def test_make_simple(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + config.optax_name = "scale" + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (schedule_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + grads = jax.tree.map(jnp.ones_like, params) + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = schedule_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g + chex.assert_trees_all_close(updates, jax.tree.map(make_tx(sched), grads)) + + def test_make_wd(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2., "other": 3.}, + }) # pyformat: disable + wds = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 2e-3, "bias": 5e-4, "other": 0.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.wd = 1e-3 + config.wd_mults = [ + (".*/kernel", 2.0), + (".*/bias", 0.5), + ] + config.schedule = dict(decay_type="linear") + config.optax_name = "scale" + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + grads = jax.tree.map(jnp.ones_like, params) + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state, params) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = sched_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + + def make_tx(sched): + def inner(p, g, wd): + return -sched * (config.lr * g_scale * g + p * wd) + return inner + + chex.assert_trees_all_close( + updates, jax.tree.map(make_tx(sched), params, grads, wds)) + + def test_make_clip_norm(self): + params = jax.tree.map(jnp.array, { + "Dense_0": {"kernel": 1., "bias": 2., "other": 3.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + config.optax_name = "scale" + config.grad_clip_norm = 1.0 + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + + grads = jax.tree.map(jnp.ones_like, params) + gflat = jax.tree.leaves(grads) + l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat])) + grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) + grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads) + + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched = sched_fn(step) + np.testing.assert_almost_equal( + sched, 1.0 / total_steps * (total_steps - step)) + make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g + chex.assert_trees_all_close(updates, + jax.tree.map(make_tx(sched), grads_scaled)) + + def test_make_multi(self): + params = jax.tree.map( + jnp.array, { + "Dense_0": {"kernel": 1.0, "bias": 2.0, "other": 3.0}, + "Dense_1": {"kernel": 4.0, "bias": 5.0, "other": 6.0}, + "Dense_2": {"kernel": 7.0, "bias": 8.0, "other": 9.0}, + "Dense_3": {"kernel": 10., "bias": 11., "other": 12.}, + }) # pyformat: disable + + # Manually specify lr + wd for computing expected values. + lrb = 0.01 + lr1 = 2.0 + lr2 = 0.5 + lr_mults = { + "Dense_0": {"kernel": lr1, "bias": lr1, "other": lr1}, + "Dense_1": {"kernel": lr2, "bias": lr2, "other": lr2}, + "Dense_2": {"kernel": 1.0, "bias": 1.0, "other": 1.0}, + "Dense_3": {"kernel": 1.0, "bias": 1.0, "other": 1.0}, + } # pyformat: disable + wdb = 1e-3 + wd1 = 10.0 + wd2 = 0.1 + wds = jax.tree.map( + jnp.array, { + "Dense_0": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_1": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_2": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.}, + "Dense_3": {"kernel": 0.0 * wdb, "bias": 0.0 * wdb, "other": 0.}, + }) # pyformat: disable + + config = ml_collections.ConfigDict() + config.lr = lrb + config.lr_mults = [ + ("Dense_0/.*", lr1), + ("Dense_1/.*", lr2), + ] + config.wd = wdb + config.wd_mults = [ + (".*/kernel", wd1), + (".*/bias", wd2), + ] + mult1 = 1.0 + mult2 = 0.1 + config.schedule = [ + ("Dense_0/.*", dict(decay_type="linear", mult=mult1, linear_end=mult1)), + ("Dense_[12]/.*", dict(decay_type="linear", mult=mult2)), + (".*", None), + ] + config.optax_name = "scale" + config.grad_clip_norm = 1.0 + config.optax = ml_collections.ConfigDict() + g_scale = 0.5 + config.optax.step_size = g_scale + + total_steps = 10 + sched_kw = dict(global_batch_size=1, total_steps=total_steps) + tx, (sched_fn1, + sched_fn2) = bv_optax.make(config, params, sched_kw=sched_kw) + opt_state = tx.init(params) + + # Manually specify schedules for computing expected values. + frozen_fn = lambda _: jnp.array(0.) + sched_fns = { + "Dense_0": {"kernel": sched_fn1, "bias": sched_fn1, "other": sched_fn1}, + "Dense_1": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2}, + "Dense_2": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2}, + "Dense_3": {"kernel": frozen_fn, "bias": frozen_fn, "other": frozen_fn}, + } # pyformat: disable + + grads = jax.tree.map(jnp.ones_like, params) + gflat, _ = jax.tree.flatten( + # Don't count frozen params towards gradient norm. + jax.tree.map(lambda g, sched_fn: {frozen_fn: 0}.get(sched_fn, g), + grads, sched_fns)) + l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat])) + grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) + grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads) + + def make_tx(step): + def get_update(p, g, wd, sched_fn, lr_mult): + return -sched_fn(step) * (lrb * lr_mult * g_scale * g + p * wd) + return get_update + + for step in range(total_steps): + updates, opt_state = tx.update(grads, opt_state, params) + self.assertEqual(bv_optax.get_count(opt_state), step + 1) + sched1, sched2 = sched_fn1(step), sched_fn2(step) + np.testing.assert_almost_equal(sched1, mult1) + np.testing.assert_almost_equal(sched2, + mult2 * (total_steps - step) / total_steps) + chex.assert_trees_all_close( + updates, + jax.tree.map( + make_tx(step), params, grads_scaled, wds, sched_fns, lr_mults)) + + def test_frozen_no_state(self): + params = {"small": jnp.zeros([1]), "large": jnp.zeros([1000])} + config = ml_collections.ConfigDict() + config.lr = 0.01 + config.schedule = [ + ("small", dict(decay_type="cosine")), + ("large", None), + ] + config.optax_name = "scale_by_adam" + + sched_kw = dict(global_batch_size=1, total_steps=1) + tx, _ = bv_optax.make(config, params, sched_kw=sched_kw) + + opt_state = tx.init(params) + adam_state = bv_optax.find_states(opt_state, optax.ScaleByAdamState) + nbytes = sum( + jax.tree.flatten(jax.tree.map(lambda x: x.nbytes, adam_state))[0]) + self.assertLess(nbytes, 1_000) + + def test_adafactor(self): + params = {"Dense_0": {"kernel": jnp.zeros([1024, 1024])}} + + config = ml_collections.ConfigDict() + config.optax_name = "big_vision.scale_by_adafactor" + config.lr = 0.01 + config.schedule = dict(decay_type="linear") + sched_kw = dict(global_batch_size=1, total_steps=1) + + tx, _ = bv_optax.make(config, params, sched_kw=sched_kw) + + opt_state = tx.init(params) + adafactor_state = bv_optax.find_states(opt_state, optax.FactoredState) + n_state_params = sum( + jax.tree.flatten( + jax.tree.map(lambda x: np.prod( + x.shape if hasattr(x, "shape") else 0), adafactor_state))[0]) + self.assertEqual(n_state_params, 2 * 1024 + 2) + + +if __name__ == "__main__": + absltest.main() diff --git a/big_vision/pp/__init__.py b/big_vision/pp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/__pycache__/__init__.cpython-310.pyc b/big_vision/pp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/__pycache__/registry.cpython-310.pyc b/big_vision/pp/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/archive/__init__.py b/big_vision/pp/archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/archive/autoaugment.py b/big_vision/pp/archive/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/archive/randaug.py b/big_vision/pp/archive/randaug.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/pp/autoaugment.py b/big_vision/pp/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc45f14e5d8c49cb54c649104851e0729ebb180 --- /dev/null +++ b/big_vision/pp/autoaugment.py @@ -0,0 +1,700 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AutoAugment and RandAugment policies for enhanced image preprocessing. + +AutoAugment Reference: https://arxiv.org/abs/1805.09501 +RandAugment Reference: https://arxiv.org/abs/1909.13719 + +This code is forked from +https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import dataclasses +import inspect +import math +import tensorflow.compat.v1 as tf +from tensorflow_addons import image as contrib_image + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + + +@dataclasses.dataclass +class HParams: + """Parameters for AutoAugment and RandAugment.""" + cutout_const: int + translate_const: int + + +def policy_v0(): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + return policy + + +def policy_vtest(): + """Autoaugment test policy for debugging.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)], + ] + return policy + + +def blend(image1, image2, factor): + """Blend image1 and image2 using 'factor'. + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. + Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + Returns: + A blended image Tensor of type uint8. + """ + if factor == 0.0: + return tf.convert_to_tensor(image1) + if factor == 1.0: + return tf.convert_to_tensor(image2) + + image1 = tf.to_float(image1) + image2 = tf.to_float(image2) + + difference = image2 - image1 + scaled = factor * difference + + # Do addition in float. + temp = tf.to_float(image1) + scaled + + # Interpolate + if factor > 0.0 and factor < 1.0: + # Interpolation means we always stay within 0 and 255. + return tf.cast(temp, tf.uint8) + + # Extrapolate: + # + # We need to clip and then cast. + return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) + + +def cutout(image, pad_size, replace=0): + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + Args: + image: An image Tensor of type uint8. + pad_size: Specifies how big the zero mask that will be generated is that + is applied to the image. The mask will be of size + (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has + the cutout mask applied to it. + Returns: + An image Tensor that is of type uint8. + """ + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] + + # Sample the center location in the image where the zero mask will be applied. + cutout_center_height = tf.random_uniform( + shape=[], minval=0, maxval=image_height, + dtype=tf.int32) + + cutout_center_width = tf.random_uniform( + shape=[], minval=0, maxval=image_width, + dtype=tf.int32) + + lower_pad = tf.maximum(0, cutout_center_height - pad_size) + upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = tf.maximum(0, cutout_center_width - pad_size) + right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad)] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.tile(mask, [1, 1, 3]) + image = tf.where( + tf.equal(mask, 0), + tf.ones_like(image, dtype=image.dtype) * replace, + image) + return image + + +def solarize(image, threshold=128): + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) + + +def color(image, factor): + """Equivalent of PIL Color.""" + degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) + return blend(degenerate, image, factor) + + +def contrast(image, factor): + """Equivalent of PIL Contrast.""" + degenerate = tf.image.rgb_to_grayscale(image) + # Cast before calling tf.histogram. + degenerate = tf.cast(degenerate, tf.int32) + + # Compute the grayscale histogram, then compute the mean pixel value, + # and create a constant image size of that value. Use that as the + # blending degenerate target of the original image. + hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) + mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) + return blend(degenerate, image, factor) + + +def brightness(image, factor): + """Equivalent of PIL Brightness.""" + degenerate = tf.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image, bits): + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) + + +def rotate(image, degrees, replace): + """Rotates the image by degrees either clockwise or counterclockwise. + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + replace: A one or three value 1D tensor to fill empty pixels caused by + the rotate operation. + Returns: + The rotated version of image. + """ + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + # In practice, we should randomize the rotation degrees by flipping + # it negatively half the time, but that's done on 'degrees' outside + # of the function. + image = contrib_image.rotate(wrap(image), radians) + return unwrap(image, replace) + + +def translate_x(image, pixels, replace): + """Equivalent of PIL Translate in X dimension.""" + image = contrib_image.translate(wrap(image), [-pixels, 0]) + return unwrap(image, replace) + + +def translate_y(image, pixels, replace): + """Equivalent of PIL Translate in Y dimension.""" + image = contrib_image.translate(wrap(image), [0, -pixels]) + return unwrap(image, replace) + + +def shear_x(image, level, replace): + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = contrib_image.transform( + wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def shear_y(image, level, replace): + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = contrib_image.transform( + wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def autocontrast(image): + """Implements Autocontrast function from PIL using TF ops. + Args: + image: A 3D uint8 tensor. + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image): + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = tf.to_float(tf.reduce_min(image)) + hi = tf.to_float(tf.reduce_max(image)) + + # Scale the image, making the lowest value 0 and the highest value 255. + def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.to_float(im) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image + + +def sharpness(image, factor): + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, + shape=[3, 3, 1, 1]) / 13. + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + with tf.device('/cpu:0'): + # Some augmentation that uses depth-wise conv will cause crashing when + # training on GPU. See ((internal link)) for details. + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding='VALID', rate=[1, 1]) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + + # Blend the final result. + return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image): + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], 2) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. + Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = flattened_image[:, 3] + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level/_MAX_LEVEL) * 30. + level = _randomly_negate_tensor(level) + return (level,) + + +def _shrink_level_to_arg(level): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0,) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level,) + + +def _enhance_level_to_arg(level): + return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level, translate_const): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + 'TranslateY': lambda level: _translate_level_to_arg( + level, hparams.translate_const), + # pylint:enable=g-long-lambda + } + + +def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(augmentation_hparams)[name](level) + + # Check to see if prob is passed into function. This is used for operations + # where we alter bboxes independently. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + args = tuple([prob] + list(args)) + # pytype:enable=wrong-arg-types + + # Add in replace arg if it is required for the function that is being called. + # pytype:disable=wrong-arg-types + if 'replace' in inspect.getfullargspec(func).args: + # Make sure replace is the final argument + assert 'replace' == inspect.getfullargspec(func).args[-1] + args = tuple(list(args) + [replace_value]) + # pytype:enable=wrong-arg-types + + return (func, prob, args) + + +def _apply_func_with_prob(func, image, args, prob): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + + # If prob is a function argument, then this randomness is being handled + # inside the function, so make sure it is always called. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getfullargspec(func).args: + prob = 1.0 + # pytype:enable=wrong-arg-types + + # Apply the function with probability `prob`. + should_apply_op = tf.cast( + tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) + augmented_image = tf.cond( + should_apply_op, + lambda: func(image, *args), + lambda: image) + return augmented_image + + +def select_and_apply_random_policy(policies, image): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) + # Note that using tf.case instead of tf.conds would result in significantly + # larger graphs and would even break export for some larger policies. + for (i, policy) in enumerate(policies): + image = tf.cond( + tf.equal(i, policy_to_select), + lambda selected_policy=policy: selected_policy(image), + lambda: image) + return image + + +def build_and_apply_nas_policy(policies, image, + augmentation_hparams): + """Build a policy from the given policies passed in and apply to image. + Args: + policies: list of lists of tuples in the form `(func, prob, level)`, `func` + is a string name of the augmentation function, `prob` is the probability + of applying the `func` operation, `level` is the input argument for + `func`. + image: tf.Tensor that the resulting policy will be applied to. + augmentation_hparams: Hparams associated with the NAS learned policy. + Returns: + A version of image that now has data augmentation applied to it based on + the `policies` pass into the function. + """ + replace_value = [128, 128, 128] + + # func is the string name of the augmentation function, prob is the + # probability of applying the operation and level is the parameter associated + # with the tf op. + + # tf_policies are functions that take in an image and return an augmented + # image. + tf_policies = [] + for policy in policies: + tf_policy = [] + # Link string name to the correct python function and make sure the correct + # argument is passed into that function. + for policy_info in policy: + policy_info = list(policy_info) + [replace_value, augmentation_hparams] + + tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue + # on image. + def make_final_policy(tf_policy_): + def final_policy(image_): + for func, prob, args in tf_policy_: + image_ = _apply_func_with_prob( + func, image_, args, prob) + return image_ + return final_policy + tf_policies.append(make_final_policy(tf_policy)) + + augmented_image = select_and_apply_random_policy( + tf_policies, image) + return augmented_image + + +def distort_image_with_autoaugment(image, augmentation_name): + """Applies the AutoAugment policy to `image`. + AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + augmentation_name: The name of the AutoAugment policy to use. The available + options are `v0` and `test`. `v0` is the policy used for + all of the results in the paper and was found to achieve the best results + on the COCO dataset. `v1`, `v2` and `v3` are additional good policies + found on the COCO dataset that have slight variation in what operations + were used during the search procedure along with how many operations are + applied in parallel to a single image (2 vs 3). + Returns: + A tuple containing the augmented versions of `image`. + """ + available_policies = {'v0': policy_v0, + 'test': policy_vtest} + if augmentation_name not in available_policies: + raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) + + policy = available_policies[augmentation_name]() + # Hparams that will be used for AutoAugment. + augmentation_hparams = HParams( + cutout_const=100, translate_const=250) + + return build_and_apply_nas_policy(policy, image, augmentation_hparams) + + +def distort_image_with_randaugment(image, num_layers, magnitude): + """Applies the RandAugment policy to `image`. + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range + [5, 30]. + Returns: + The augmented version of `image`. + """ + replace_value = [128] * 3 + tf.logging.info('Using RandAug.') + augmentation_hparams = HParams( + cutout_const=40, translate_const=100) + available_ops = [ + 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', + 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', + 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] + + for layer_num in range(num_layers): + op_to_select = tf.random_uniform( + [], maxval=len(available_ops), dtype=tf.int32) + random_magnitude = float(magnitude) + with tf.name_scope('randaug_layer_{}'.format(layer_num)): + for (i, op_name) in enumerate(available_ops): + prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, prob, random_magnitude, + replace_value, augmentation_hparams) + image = tf.cond( + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), + # pylint:enable=g-long-lambda + lambda: image) + return image diff --git a/big_vision/pp/builder.py b/big_vision/pp/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8e254cdfdd17c3b58e1c9522bbc2a5be1bb089b9 --- /dev/null +++ b/big_vision/pp/builder.py @@ -0,0 +1,85 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preprocessing builder.""" + +from absl import logging +from big_vision.pp import registry +import tensorflow as tf + + +def get_preprocess_fn(pp_pipeline, log_data=True, log_steps=False): + """Transform an input string into the preprocessing function. + + The minilanguage is as follows: + + fn1|fn2(arg, arg2,...)|... + + And describes the successive application of the various `fn`s to the input, + where each function can optionally have one or more arguments, which are + either positional or key/value, as dictated by the `fn`. + + The output preprocessing function expects a dictionary as input. This + dictionary should have a key "image" that corresponds to a 3D tensor + (height x width x channel). + + Args: + pp_pipeline: A string describing the pre-processing pipeline. If empty or + None, no preprocessing will be executed. + log_data: Whether to log the data before and after preprocessing. Can also + be a string to show in the log for debugging, for example dataset name. + log_steps: Whether to log the steps of the preprocessing pipeline. + + Returns: + preprocessing function. + + Raises: + ValueError: if preprocessing function name is unknown + """ + + names, ops, spec_strings = [], [], [] + if pp_pipeline: + for op_spec in pp_pipeline.split("|"): + if not op_spec: continue # Skip empty section instead of error. + try: + ops.append(registry.Registry.lookup(f"preprocess_ops.{op_spec}")()) + names.append(registry.parse_name(op_spec)[0]) + spec_strings.append(op_spec) + except SyntaxError as err: + raise ValueError(f"Syntax error on: {op_spec}") from err + + def _preprocess_fn(data): + """The preprocessing function that is returned.""" + nonlocal log_data, log_steps + + # Apply all the individual steps in sequence. + if log_data: + logging.info("Data before pre-processing (%s):\n%s", log_data, data) + for name, op, spec in zip(names, ops, spec_strings): + if log_steps: + logging.info("Pre-processing step (%s): %s\n%s", name, spec, data) + with tf.name_scope(name): + data = op(data) + + # Validate input + if not isinstance(data, dict): + raise ValueError("Argument `data` must be a dictionary, " + "not %s" % str(type(data))) + + if log_data: + logging.info("Data after pre-processing (%s):\n%s", log_data, data) + log_data = False # For eager&pygrain: only log first one of each pipeline. + return data + + return _preprocess_fn diff --git a/big_vision/pp/builder_test.py b/big_vision/pp/builder_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a75cc05417a7f3ace70f97583d2e7b1ae4c432 --- /dev/null +++ b/big_vision/pp/builder_test.py @@ -0,0 +1,72 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for builder.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from big_vision.pp import builder +from big_vision.pp import ops_general # pylint: disable=unused-import +from big_vision.pp import ops_image # pylint: disable=unused-import +import numpy as np +import tensorflow.compat.v1 as tf + + +class BuilderTest(tf.test.TestCase): + + def testSingle(self): + pp_fn = builder.get_preprocess_fn("resize(256)") + x = np.random.randint(0, 256, [640, 480, 3]) + image = pp_fn({"image": x})["image"] + self.assertEqual(image.numpy().shape, (256, 256, 3)) + + def testEmpty(self): + pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||") + + # Typical image input + x = np.random.randint(0, 256, [640, 480, 3]) + image = pp_fn({"image": x})["image"] + self.assertEqual(image.numpy().shape, (256, 256, 3)) + + def testPreprocessingPipeline(self): + pp_str = ("inception_crop|resize(256)|resize((256, 256))|" + "central_crop((80, 120))|flip_lr|value_range(0,1)|" + "value_range(-1,1)") + pp_fn = builder.get_preprocess_fn(pp_str) + + # Typical image input + x = np.random.randint(0, 256, [640, 480, 3]) + image = pp_fn({"image": x})["image"] + self.assertEqual(image.numpy().shape, (80, 120, 3)) + self.assertLessEqual(np.max(image.numpy()), 1) + self.assertGreaterEqual(np.min(image.numpy()), -1) + + def testNumArgsException(self): + + x = np.random.randint(0, 256, [640, 480, 3]) + for pp_str in [ + "inception_crop(1)", + "resize()", + "resize(1, 1, 1)" + "flip_lr(1)", + "central_crop()", + ]: + with self.assertRaises(BaseException): + builder.get_preprocess_fn(pp_str)(x) + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/pp/ops_general.py b/big_vision/pp/ops_general.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5cebd07b34b9d4ba56e5da52d7a06125a04304 --- /dev/null +++ b/big_vision/pp/ops_general.py @@ -0,0 +1,465 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic tensor preprocessing ops. + +All preprocessing ops should return a data processing functors. A data +is represented as a dictionary of (TF) tensors. The functors output a modified +dictionary. +""" + +import collections + +from big_vision.pp import utils +from big_vision.pp.registry import Registry +import big_vision.utils as bv_utils +import jax +import numpy as np +import tensorflow as tf + + +@Registry.register("preprocess_ops.value_range") +@utils.InKeyOutKey() +def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False): + """Transforms a [in_min,in_max] image to [vmin,vmax] range. + + Input ranges in_min/in_max can be equal-size lists to rescale the invidudal + channels independently. + + Args: + vmin: A scalar. Output max value. + vmax: A scalar. Output min value. + in_min: A scalar or a list of input min values to scale. If a list, the + length should match to the number of channels in the image. + in_max: A scalar or a list of input max values to scale. If a list, the + length should match to the number of channels in the image. + clip_values: Whether to clip the output values to the provided ranges. + + Returns: + A function to rescale the values. + """ + + def _value_range(image): + """Scales values in given range.""" + in_min_t = tf.constant(in_min, tf.float32) + in_max_t = tf.constant(in_max, tf.float32) + image = tf.cast(image, tf.float32) + image = (image - in_min_t) / (in_max_t - in_min_t) + image = vmin + image * (vmax - vmin) + if clip_values: + image = tf.clip_by_value(image, vmin, vmax) + return image + + return _value_range + + +@Registry.register("preprocess_ops.lookup") +@utils.InKeyOutKey() +def get_lookup(mapping, npzkey="fnames", sep=None): + """Map string to number.""" + + # For NumPy files, we use the `npzkey` array in that file as the list of + # strings which are mapped to their index in that array. + # This is especially useful when other data (eg precomputed predictions) + # goes along with this mapping, to have everything in one place (the npz). + if mapping.endswith(".npz"): + with tf.io.gfile.GFile(mapping, "rb") as f: + keys = np.array(np.load(f, allow_pickle=False)[npzkey]) + vals = np.arange(len(keys)) + + # Otherwise, we simply use the file as a text file, with either of: + # - a string per line, mapped to its line-number + # - a pair, separated by `sep` per line, first value being the string, second + # value being the integer that the string is mapped to. + else: + with tf.io.gfile.GFile(mapping, "r") as f: + buf = f.read() + if sep is None: # values are the line numbers + keys = buf.splitlines() + vals = np.arange(len(keys)) + else: # each line is keyval, also make val int + keys, vals = zip(*[l.split(sep) for l in buf.splitlines()]) + vals = [int(v) for v in vals] + + def _do_the_mapping(needle): + """Map string to number.""" + with tf.init_scope(): # (Originally added for performance reasons.) + table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(keys, vals), -1) + return table.lookup(needle) + + return _do_the_mapping + + +@Registry.register("preprocess_ops.onehot") +def get_onehot(depth, + key="labels", + key_result=None, + multi=True, + on=1.0, + off=0.0): + """One-hot encodes the input. + + Args: + depth: Length of the one-hot vector (how many classes). + key: Key of the data to be one-hot encoded. + key_result: Key under which to store the result (same as `key` if None). + multi: If there are multiple labels, whether to merge them into the same + "multi-hot" vector (True) or keep them as an extra dimension (False). + on: Value to fill in for the positive label (default: 1). + off: Value to fill in for negative labels (default: 0). + + Returns: + Data dictionary. + """ + + def _onehot(data): + # When there's more than one label, this is significantly more efficient + # than using tf.one_hot followed by tf.reduce_max; we tested. + labels = data[key] + labels = tf.cast(labels, tf.int64) # both scatter and one_hot expect this + if labels.shape.rank > 0 and multi: + x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,)) + x = tf.clip_by_value(x, 0, 1) * (on - off) + off + else: + x = tf.one_hot(labels, depth, on_value=on, off_value=off) + data[key_result or key] = x + return data + + return _onehot + + +@Registry.register("preprocess_ops.keep") +def get_keep(*keys): + """Keeps only the given keys.""" + + def _keep(data): + return {k: v for k, v in data.items() if k in keys} + + return _keep + + +@Registry.register("preprocess_ops.drop") +def get_drop(*keys): + """Drops the given keys.""" + + def _drop(data): + return {k: v for k, v in data.items() if k not in keys} + + return _drop + + +@Registry.register("preprocess_ops.copy") +def get_copy(inkey, outkey): + """Copies value of `inkey` into `outkey`.""" + + def _copy(data): + # A "semi-deep" copy. deepcopy doesn't work when tf tensors are part of the + # game. What we want, is to only copy the python structure (dicts, lists) + # and keep tensors as they are, since we never modify them in-place anyways. + # The following achieves exactly that. + data[outkey] = jax.tree.map(lambda x: x, data[inkey]) + return data + + return _copy + + +@Registry.register("preprocess_ops.squeeze_last_dim") +@utils.InKeyOutKey() +def get_squeeze_last_dim(): + def _squeeze_last_dim(x): + return tf.squeeze(x, axis=-1) + return _squeeze_last_dim + + +@Registry.register("preprocess_ops.concat") +def get_concat(inkeys, outkey=None, axis=-1): + """Concatenates elements along some axis.""" + + def _concat(data): + data[outkey or inkeys[0]] = tf.concat([data[k] for k in inkeys], axis) + return data + + return _concat + + +@Registry.register("preprocess_ops.rag_tensor") +@utils.InKeyOutKey() +def get_rag_tensor(): + """Converts the specified feature to ragged tensor.""" + + def rag_tensor(raw_tensor): + # Note: Add one more dimension as `from_tensor` requires at least rank 2. + return tf.RaggedTensor.from_tensor(raw_tensor[None]) + + return rag_tensor + + +@Registry.register("preprocess_ops.pad_to_shape") +@utils.InKeyOutKey() +def get_pad_to_shape(shape, pad_value=0, where="after"): + """Pads tensor to specified `shape`.""" + + def _pads(cur, tgt): + if tgt is None: + return [0, 0] + diff = tgt - cur + return { + "before": [diff, 0], + "after": [0, diff], + "both": [diff // 2, diff - diff // 2], + }[where] + + def _pad_to_shape(x): + assert len(x.shape.as_list()) == len(shape) + paddings = [_pads(tgt=shape[i], cur=tf.shape(x)[i]) + for i in range(len(shape))] + constant_value = tf.constant(pad_value, x.dtype) + ret = tf.pad(x, paddings, constant_values=constant_value) + ret.set_shape(shape) + return ret + + return _pad_to_shape + + +@Registry.register("preprocess_ops.flatten") +def get_flatten(): + """Flattens the keys of data with separator '/'.""" + + def flatten(data): + flat, _ = bv_utils.tree_flatten_with_names(data) + return dict(flat) + + return flatten + + +@Registry.register("preprocess_ops.reshape") +@utils.InKeyOutKey() +def get_reshape(new_shape): + """Reshapes tensor to a given new shape. + + Args: + new_shape: new shape for the tensor. + + Returns: + A function for reshaping a tensor. + + """ + + def _reshape(tensor): + """Reshapes a tensor to a given shape.""" + dtype = tensor.dtype + tensor = tf.reshape(tensor, new_shape) + return tf.cast(tensor, dtype) + + return _reshape + + +@Registry.register("preprocess_ops.setdefault") +def get_setdefault(key, value): + """If `key` is an empty tensor or missing, set it to `value`.""" + def _setdefault(data): + x = data.get(key, tf.constant(value)) + v = tf.constant(value, dtype=x.dtype) + v = tf.broadcast_to(v, [s or 1 for s in x.shape]) + data[key] = tf.cond(tf.size(x) > 0, lambda: x, lambda: v) + return data + return _setdefault + + +@Registry.register("preprocess_ops.choice") +def get_choice(n="single", key=None, fewer_ok=False, inkey=None, outkey=None): + """Chooses the same `n` random entries of all `keys`. + + Args: + n: how many entries to randomly sample (without repeat). Possible values: + - int: that many entries (or fewer if there's fewer, see `fewer_ok`.) + - "single": The string "single" only chooses one and drop the leading dim. + - [min, max]: A pair means randomly take between min/max examples (incl.). + key: str or list of str: See Note. + fewer_ok: whether to fail when there's fewer than `n` elements to choose + from (and hence set static shape to `n`), or whether to allow it. + (and hence have unknown static shape). + inkey: str or list of str: See Note. + outkey: str or list of str: See Note. + + Note: + If key/inkey/outkey is a list, then the same random entries are chosen for + all of the keys. Other than that, they function the same as InKeyOutKey. + + The outkey can also contain the placeholder `{key}` that'll be . + + Examples: + choice(key="alt_text/text") + choice(n=128, key=["patches", "positions"]) + choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"]) + + Returns: + The pp op. + """ + + # Normalize keys: + inkeys = utils.maybe_repeat(inkey or key, 1) + outkeys = utils.maybe_repeat(outkey or key, 1) + outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)] + + # Let's DRY on this condition and give it a name. + is_varlen = isinstance(n, (list, tuple)) + min_n = n[0] if is_varlen else 1 if n == "single" else n + + def _choice(data): + # Catch a hard to identify/understand user error: + assert data[inkeys[0]].ndim > 0, ( + f"You're calling `choice_no_replacement` on {inkeys}, a scalar." + " That's probably a mistake ; double-check and then just don't." + ) + + nitems = tf.shape(data[inkeys[0]])[0] + + # Sanity check that all keys have same leading dimension, and that is at + # least as large as the minimum requested output. + lengths = [tf.shape(data[k])[0] for k in inkeys] + checks = [tf.debugging.assert_equal(l, nitems) for l in lengths] + if not fewer_ok: # Since we check for all-same, a single suffices here. + checks.append(tf.debugging.assert_greater_equal(nitems, min_n)) + with tf.control_dependencies(checks): + nitems = tf.identity(nitems) + + if n == "single": + index = tf.random.uniform([], 0, nitems, dtype=tf.int32) + else: + # Subsample by shuffling and taking first n, but... + indices = tf.random.shuffle(tf.range(nitems)) + end = n + if is_varlen: + end = tf.random.uniform([], n[0], n[1] + 1, dtype=tf.int32) + # ...keep the order while subsampling (it might have a meaning, eg boxes) + indices = tf.sort(indices[:end]) + + for ik, ok in zip(inkeys, outkeys): + if n == "single": + result = data[ik][index] + else: + result = tf.gather(data[ik], indices, axis=0) + if not is_varlen: # Give static shape when we can. + result = tf.ensure_shape(result, [n] + [None] * (result.ndim - 1)) + data[ok] = result + + return data + return _choice + + +def _shuffled_index(count, nitems, seed): + """Returns index from a shuffled sequence (items only repeat after epoch).""" + nitems = tf.cast(nitems, count.dtype) + item_epoch, item_offset = (count // nitems, count % nitems) + shuffled_indices = tf.random.experimental.stateless_shuffle( + tf.range(nitems), seed=tf.random.fold_in(seed, item_epoch)) + return shuffled_indices[item_offset] + + +@Registry.register("preprocess_ops.choice_no_replacement") +def get_choice_no_replacement(key=None, inkey=None, outkey=None): + """Chooses the same random (no replacement) entry of all `keys`. + + Note: Consider using this for iterating over small datasets with a small + number of epochs. It differs from `choice(n='single')` in that if an example, + as identified by its `_id` field, is seen N times then it will cycled through + all the inkeys values before repeating them. Additionally each repetition uses + a different order. + + Caveats: requires dataset to provide a _id field and uses host RAM to keep a + counter how often each id is seen. It is also not robust to preemptions. + + Args: + key: str or list of str: See Note. + inkey: str or list of str: See Note. + outkey: str or list of str: See Note. + + Note: + If key/inkey/outkey is a list, then the same random entries are chosen for + all of the keys. Other than that, they function the same as InKeyOutKey. + + The outkey can also contain the placeholder `{key}` that'll be replaced + by the inkey name. + + Examples: + choice(key="alt_text/text") + choice(key=["patches", "positions"]) + choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"]) + + Returns: + The pp op. + """ + # Normalize keys: + inkeys = utils.maybe_repeat(inkey or key, 1) + outkeys = utils.maybe_repeat(outkey or key, 1) + outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)] + + # TODO: Ideally the data pipeline should provide us with an epoch + # counter. For now count how often we see a given example id and don't worry + # on memory consumption. Counter returns 0 the first time an example is seen. + counter = collections.defaultdict(lambda: -1) + def _seen_count(example_id): + example_id = example_id.item() + counter[example_id] += 1 + return counter[example_id] + + # We need a seed to deterministically decide on a shuffled sequence and use + # the number of times an example was seen to iterate through it. The seed + # should be different for every instance of a create preprocessing function + # but it has to be fixed for each instance. + seed = tf.random.uniform( + [2], minval=tf.int32.min, maxval=tf.int32.max, dtype=tf.int32) + + def _choice(data): + # Catch a hard to identify/understand user error: + assert data[inkeys[0]].ndim > 0, ( + f"You're calling `choice` on {inkeys}, a scalar." + " That's probably a mistake ; double-check and then just don't." + ) + + nitems = tf.shape(data[inkeys[0]])[0] + + # Sanity check that all keys have same leading dimension. + checks = [ + tf.debugging.assert_equal(tf.shape(data[k])[0], nitems) + for k in inkeys + ] + with tf.control_dependencies(checks): + nitems = tf.identity(nitems) + + # Using the seed, example id and the number of times an example was seen + # pick an `index` such that items are only repeated after all items are seen + # an equal number of times. E.g. it could return indexes from this sequence: + # [0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 2, 1, ...]. + count = tf.numpy_function( + _seen_count, (data["_id"],), Tout=tf.int64, stateful=True) + count = tf.cast(count, tf.int32) + nitems = tf.cast(nitems, tf.int32) + shuffle_epoch = count // nitems + shuffle_offset = count % nitems + + example_seed = tf.random.fold_in(seed, data["_id"]) + shuffle_seed = tf.random.fold_in(example_seed, shuffle_epoch) + shuffle = tf.random.experimental.stateless_shuffle( + tf.range(nitems), seed=shuffle_seed) + index = shuffle[shuffle_offset] + + # Select item[index] for all keys. + for ik, ok in zip(inkeys, outkeys): + data[ok] = data[ik][index] + return data + + return _choice diff --git a/big_vision/pp/ops_general_test.py b/big_vision/pp/ops_general_test.py new file mode 100644 index 0000000000000000000000000000000000000000..89f616e1690c6e83aff818cf0fff540dcad073fd --- /dev/null +++ b/big_vision/pp/ops_general_test.py @@ -0,0 +1,236 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ops_general.""" + +import copy + +import big_vision.pp.ops_general as pp +import numpy as np +import tensorflow as tf + + +class PreprocessOpsTest(tf.test.TestCase): + + def tfrun(self, ppfn, data): + # Run once as standalone, as could happen eg in colab. + yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()} + + # And then once again as part of tfdata pipeline. + # You'd be surprised how much these two differ! + tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) + for npdata in tfdata.map(ppfn).as_numpy_iterator(): + yield npdata + + def test_value_range(self): + img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32) + data = {"image": tf.cast(img, tf.uint8)} + for out in self.tfrun(pp.get_value_range(-0.5, 0.5), data): + self.assertLessEqual(np.max(out["image"]), 0.5) + self.assertGreaterEqual(np.min(out["image"]), -0.5) + + def test_value_range_custom_input_range(self): + img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32) + data = {"image": tf.cast(img, tf.uint8)} + for out in self.tfrun(pp.get_value_range(-0.5, 0.5, -256, 255, True), data): + self.assertLessEqual(np.max(out["image"]), 0.5) + self.assertGreaterEqual(np.min(out["image"]), 0.0) + + def test_get_keep_drop(self): + data = {"image": 1, "labels": 2, "something": 3} + + for data_keep in self.tfrun(pp.get_keep("image", "labels"), data): + self.assertAllEqual(set(data_keep.keys()), {"image", "labels"}) + + for data_drop in self.tfrun(pp.get_drop("image", "labels"), data): + self.assertAllEqual(set(data_drop.keys()), {"something"}) + + def test_onehot(self): + data = {"labels": tf.constant(2, dtype=tf.int64)} + for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data): + self.assertAllClose(out["labels"], [0., 0., 1., 0.]) + + def test_onehot_multi(self): + data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)} + for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data): + self.assertAllClose(out["labels"], [ + [0., 0., 1., 0.], + [0., 0., 0., 1.], + [1., 0., 0., 0.]]) + + for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data): + self.assertAllClose(out["labels"], [1., 0., 1., 1.]) + + def test_onehot_2d(self): + data = {"labels": tf.constant([[2, 3], [0, 1]], dtype=tf.int64)} + for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data): + self.assertAllClose(out["labels"], [ + [[0., 0., 1., 0.], [0., 0., 0., 1.]], + [[1., 0., 0., 0.], [0., 1., 0., 0.]]]) + + def test_onehot_smoothing(self): + data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)} + for out in self.tfrun( + pp.get_onehot(4, "labels", multi=False, on=0.8, off=0.1), data): + self.assertAllClose(out["labels"], [ + [0.1, 0.1, 0.8, 0.1], + [0.1, 0.1, 0.1, 0.8], + [0.8, 0.1, 0.1, 0.1]]) + + for out in self.tfrun( + pp.get_onehot(4, "labels", multi=True, on=0.8, off=0.1), data): + self.assertAllClose(out["labels"], [0.8, 0.1, 0.8, 0.8]) + + def test_squeeze_last_dim(self): + data = {"image": tf.constant(np.zeros((32, 32, 3, 1)))} + for out in self.tfrun(pp.get_squeeze_last_dim(), data): + self.assertAllEqual(out["image"].shape, [32, 32, 3]) + + def test_pad_to_shape(self): + desired_shape = (8, 10) + for input_shape in [(8, 4), (8, 3), (8, 10), (8, 1)]: + data = {"x": tf.ones(input_shape, dtype=tf.float32)} + for out in self.tfrun( + pp.get_pad_to_shape(desired_shape, pad_value=-1, key="x"), data): + self.assertEqual( + tf.reduce_sum(out["x"]), + 2 * np.prod(input_shape) - np.prod(desired_shape)) + + def test_pad_to_shape_none(self): + data = {"x": tf.ones((8, 4), dtype=tf.float32)} + for out in self.tfrun( + pp.get_pad_to_shape((None, 6), pad_value=-1, key="x"), data): + self.assertEqual(out["x"].shape, (8, 6)) + self.assertEqual(tf.reduce_sum(out["x"]), 8*4 - 8*2) + + def test_pad_to_shape_which_side(self): + data = {"x": tf.ones((8, 4), dtype=tf.float32)} + for where, idxs in [("before", [0]), ("both", [0, -1]), ("after", [-1])]: + for out in self.tfrun( + pp.get_pad_to_shape((8, 6), key="x", where=where), data): + self.assertEqual(out["x"].shape, (8, 6)) + self.assertEqual(tf.reduce_sum(out["x"]), 8*4) + for i in idxs: + self.assertEqual(out["x"][0, i], 0) + + def test_flatten(self): + d = {"a": {"b": tf.constant([1, 2, 3])}, "c": "str"} + self.assertEqual(pp.get_flatten()(d), { + "a/b": tf.constant([1, 2, 3]), + "c": "str" + }) + + def test_reshape(self): + data = {"image": tf.constant(np.zeros((8, 32 * 32 * 3)))} + for out in self.tfrun(pp.get_reshape(new_shape=(8, 32, 32, 3)), data): + self.assertAllEqual(out["image"].shape, [8, 32, 32, 3]) + + def test_setdefault(self): + data = { + "empty_image": tf.zeros([0, 0, 0]), + "image": tf.constant(np.arange(9).reshape(3, 3)), + "empty_text": tf.zeros([0], tf.string), + "text": tf.constant(["Hello", "World"], tf.string), + } + for out in self.tfrun(pp.get_setdefault("empty_image", 1), data): + self.assertAllEqual(out["empty_image"], np.array([[[1]]])) + for out in self.tfrun(pp.get_setdefault("image", 1), data): + self.assertAllEqual(out["image"], data["image"]) + for out in self.tfrun(pp.get_setdefault("empty_text", "Lucas"), data): + self.assertAllEqual(out["empty_text"], np.array(["Lucas"])) + for out in self.tfrun(pp.get_setdefault("text", "Lucas"), data): + self.assertAllEqual(out["text"], data["text"]) + + def _data_for_choice(self): + return { + "one_f32": tf.constant([0.42], tf.float32), + "two_f32": tf.constant([3.14, 0.42], tf.float32), + "one_str": tf.constant(["Hi"], tf.string), + "two_str": tf.constant(["Hi", "Lucas"], tf.string), + "one_vec": tf.reshape(tf.range(2, dtype=tf.float32), (1, 2)), + "two_vec": tf.reshape(tf.range(4, dtype=tf.float32), (2, 2)), + } + + def test_choice(self): + # Test for the default call (n="single") + data = self._data_for_choice() + self.assertEqual( + pp.get_choice(inkey="one_f32", outkey="choice")(data)["choice"], 0.42) + self.assertEqual( + pp.get_choice(inkey="one_str", outkey="choice")(data)["choice"], "Hi") + self.assertIn( + pp.get_choice(inkey="two_f32", outkey="choice")(data)["choice"], + [3.14, 0.42]) + self.assertIn( + pp.get_choice(inkey="two_str", outkey="choice")(data)["choice"], + ["Hi", "Lucas"]) + + def test_choice_nmax(self): + # n == nelems should be identity (and keep ordering!) + data = self._data_for_choice() + for k in ("one_f32", "one_str", "one_vec"): + for out in self.tfrun(pp.get_choice(n=1, key=[k]), data): + self.assertAllEqual(out[k], data[k]) + for out in self.tfrun(pp.get_choice(n=[1, 1], key=[k]), data): + self.assertAllEqual(out[k], data[k]) + for k in ("two_f32", "two_str", "two_vec"): + for out in self.tfrun(pp.get_choice(n=2, key=[k]), data): + self.assertAllEqual(out[k], data[k]) + for out in self.tfrun(pp.get_choice(n=[2, 2], key=[k]), data): + self.assertAllEqual(out[k], data[k]) + + def test_choice_n(self): + # n < nelems should be one of them: + data = self._data_for_choice() + for k in ("two_f32", "two_str"): + for out in self.tfrun(pp.get_choice(n=1, key=[k]), data): + self.assertIn(out[k], data[k]) + + # Special testing for vectors. + for out in self.tfrun(pp.get_choice(n=1, key=["two_vec"]), data): + self.assertTrue(tf.logical_or( + tf.reduce_all(out["two_vec"][0] == data["two_vec"][0]), + tf.reduce_all(out["two_vec"][0] == data["two_vec"][1]), + )) + + def test_choice_multi(self): + # Select consistently across multiple keys. + data = self._data_for_choice() + op = pp.get_choice(n=1, key=["two_f32", "two_str"]) + for out in self.tfrun(op, data): + self.assertTrue(tf.logical_or( + tf.logical_and( + tf.reduce_all(out["two_f32"][0] == data["two_f32"][0]), + tf.reduce_all(out["two_str"][0] == data["two_str"][0]), + ), + tf.logical_and( + tf.reduce_all(out["two_f32"][0] == data["two_f32"][1]), + tf.reduce_all(out["two_str"][0] == data["two_str"][1]), + ), + )) + + def test_choice_n_range(self): + # n < nelems should be one of them: + data = self._data_for_choice() + for k in ("two_f32", "two_str", "two_vec"): + for out in self.tfrun(pp.get_choice(n=[1, 2], key=[k]), data): + self.assertTrue(tf.reduce_any([ + tf.reduce_all(out[k] == data[k][0:1]), + tf.reduce_all(out[k] == data[k][1:2]), + tf.reduce_all(out[k] == data[k][0:2]), + ])) + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/pp/ops_image.py b/big_vision/pp/ops_image.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc55a5659c3e0df0c209edef798bbd5c6a7f623 --- /dev/null +++ b/big_vision/pp/ops_image.py @@ -0,0 +1,361 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Image-centric preprocessing ops. + +All preprocessing ops should return a data processing functors. A data +is represented as a dictionary of (TF) tensors. The functors output a modified +dictionary. + +The key named "image" is commonly used for the image, and is a 3D tensor of +shape (height x width x channels). +""" + +from big_vision.pp import utils +from big_vision.pp.registry import Registry + +import tensorflow as tf + + +@Registry.register("preprocess_ops.decode") +@utils.InKeyOutKey() +def get_decode(channels=3, precise=False): + """Decode an encoded image string, see tf.io.decode_image. + + Args: + channels: see tf.io.decode_image. + precise: if False, use default TF image decoding algorithm. + If True, change DCT method for JPEG decoding to match PIL/cv2/PyTorch. + See also (internal link) for a concrete example. + + Returns: + The decoded image. + """ + + def _decode(image): + if precise: + return tf.image.decode_jpeg( # Also supports png btw. + image, channels=channels, dct_method="INTEGER_ACCURATE") + else: + return tf.io.decode_image( + image, channels=channels, expand_animations=False) + + return _decode + + +@Registry.register("preprocess_ops.resize") +@utils.InKeyOutKey() +def get_resize(size, method="bilinear", antialias=False): + """Resizes image to a given size. + + Args: + size: either an integer H, where H is both the new height and width + of the resized image, or a list or tuple [H, W] of integers, where H and W + are new image"s height and width respectively. + method: resize method, see tf.image.resize docs for options. + antialias: see tf.image.resize. Ideally set to True for all new configs. + + Returns: + A function for resizing an image. + + """ + size = utils.maybe_repeat(size, 2) + + def _resize(image): + """Resizes image to a given size.""" + # Note: use TF-2 version of tf.image.resize as the version in TF-1 is + # buggy: https://github.com/tensorflow/tensorflow/issues/6720. + # In particular it was not equivariant with rotation and lead to the network + # to learn a shortcut in self-supervised rotation task, if rotation was + # applied after resize. + dtype = image.dtype + tf_dtype = tf.type_spec_from_value(image).dtype + image = tf.image.resize(image, size, method=method, antialias=antialias) + return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype) + + return _resize + + +# This functionality is used by resize_small and resize_long. But we're not +# registering it as a pp op yet, as there is no need for it. However, it can +# probably be slightly generalized into "scale augmentation" eventually. +def _resize_factor(image, factor, method="area", antialias=True): + """Resizes the image by a (float) `factor`, keeping the aspect ratio fixed.""" + h, w = tf.shape(image)[0], tf.shape(image)[1] + + h = tf.cast(tf.round(tf.cast(h, tf.float32) * factor), tf.int32) + w = tf.cast(tf.round(tf.cast(w, tf.float32) * factor), tf.int32) + + dtype = image.dtype + tf_dtype = tf.type_spec_from_value(image).dtype + image = tf.image.resize(image, (h, w), method=method, antialias=antialias) + return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype) + + +@Registry.register("preprocess_ops.resize_small") +@utils.InKeyOutKey() +def get_resize_small(smaller_size, method="area", antialias=False): + """Resizes the smaller side to `smaller_size` keeping aspect ratio. + + Args: + smaller_size: an integer, that represents a new size of the smaller side of + an input image. + method: the resize method. `area` is a meaningful, bwd-compat default. + antialias: see tf.image.resize. Ideally set to True for all new configs. + + Returns: + A function, that resizes an image and preserves its aspect ratio. + + Note: + backwards-compat for "area"+antialias tested here: + (internal link) + """ + + def _resize_small(image): # pylint: disable=missing-docstring + h, w = tf.shape(image)[0], tf.shape(image)[1] + factor = ( + tf.cast(smaller_size, tf.float32) / + tf.cast(tf.minimum(h, w), tf.float32)) + return _resize_factor(image, factor, method=method, antialias=antialias) + return _resize_small + + +@Registry.register("preprocess_ops.resize_long") +@utils.InKeyOutKey() +def get_resize_long(longer_size, method="area", antialias=True): + """Resizes the longer side to `longer_size` keeping aspect ratio. + + Args: + longer_size: an integer, that represents a new size of the longer side of + an input image. + method: the resize method. `area` is a meaningful, bwd-compat default. + antialias: see tf.image.resize. Ideally set to True for all new configs. + + Returns: + A function, that resizes an image and preserves its aspect ratio. + """ + + def _resize_long(image): # pylint: disable=missing-docstring + h, w = tf.shape(image)[0], tf.shape(image)[1] + factor = ( + tf.cast(longer_size, tf.float32) / + tf.cast(tf.maximum(h, w), tf.float32)) + return _resize_factor(image, factor, method=method, antialias=antialias) + return _resize_long + + +@Registry.register("preprocess_ops.inception_crop") +@utils.InKeyOutKey() +def get_inception_crop(size=None, area_min=5, area_max=100, + method="bilinear", antialias=False): + """Makes inception-style image crop. + + Inception-style crop is a random image crop (its size and aspect ratio are + random) that was used for training Inception models, see + https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf. + + Args: + size: Resize image to [size, size] after crop. + area_min: minimal crop area. + area_max: maximal crop area. + method: rezied method, see tf.image.resize docs for options. + antialias: see tf.image.resize. Ideally set to True for all new configs. + + Returns: + A function, that applies inception crop. + """ + + def _inception_crop(image): # pylint: disable=missing-docstring + begin, crop_size, _ = tf.image.sample_distorted_bounding_box( + tf.shape(image), + tf.zeros([0, 0, 4], tf.float32), + area_range=(area_min / 100, area_max / 100), + min_object_covered=0, # Don't enforce a minimum area. + use_image_if_no_bounding_boxes=True) + crop = tf.slice(image, begin, crop_size) + # Unfortunately, the above operation loses the depth-dimension. So we need + # to restore it the manual way. + crop.set_shape([None, None, image.shape[-1]]) + if size: + crop = get_resize(size, method, antialias)({"image": crop})["image"] + return crop + + return _inception_crop + + +@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop") +@utils.InKeyOutKey() +def get_decode_jpeg_and_inception_crop(size=None, area_min=5, area_max=100, + ratio_min=0.75, ratio_max=1.33, + method="bilinear", antialias=False): + """Decode jpeg string and make inception-style image crop. + + Inception-style crop is a random image crop (its size and aspect ratio are + random) that was used for training Inception models, see + https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf. + + Args: + size: Resize image to [size, size] after crop. + area_min: minimal crop area. + area_max: maximal crop area. + ratio_min: minimal aspect ratio. + ratio_max: maximal aspect ratio. + method: rezied method, see tf.image.resize docs for options. + antialias: see tf.image.resize. Ideally set to True for all new configs. + + Returns: + A function, that applies inception crop. + """ + + def _inception_crop(image_data): # pylint: disable=missing-docstring + shape = tf.image.extract_jpeg_shape(image_data) + begin, crop_size, _ = tf.image.sample_distorted_bounding_box( + shape, + tf.zeros([0, 0, 4], tf.float32), + area_range=(area_min / 100, area_max / 100), + aspect_ratio_range=(ratio_min, ratio_max), + min_object_covered=0, # Don't enforce a minimum area. + use_image_if_no_bounding_boxes=True) + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(begin) + target_height, target_width, _ = tf.unstack(crop_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3) + + if size: + image = get_resize(size, method, antialias)({"image": image})["image"] + + return image + + return _inception_crop + + +@Registry.register("preprocess_ops.random_crop") +@utils.InKeyOutKey() +def get_random_crop(crop_size): + """Makes a random crop of a given size. + + Args: + crop_size: either an integer H, where H is both the height and width of the + random crop, or a list or tuple [H, W] of integers, where H and W are + height and width of the random crop respectively. + + Returns: + A function, that applies random crop. + """ + crop_size = utils.maybe_repeat(crop_size, 2) + + def _crop(image): + return tf.image.random_crop(image, (*crop_size, image.shape[-1])) + + return _crop + + +@Registry.register("preprocess_ops.central_crop") +@utils.InKeyOutKey() +def get_central_crop(crop_size=None): + """Makes central crop of a given size. + + Args: + crop_size: either an integer H, where H is both the height and width of the + central crop, or a list or tuple [H, W] of integers, where H and W are + height and width of the central crop respectively. If `crop_size` is not + specified, then the largest possible center crop will be taken. + + Returns: + A function, that applies central crop. + """ + if crop_size: + crop_size = utils.maybe_repeat(crop_size, 2) + + def _crop(image): + if crop_size: + h, w = crop_size[0], crop_size[1] + else: + h = w = tf.minimum(tf.shape(image)[0], tf.shape(image)[1]) + dy = (tf.shape(image)[0] - h) // 2 + dx = (tf.shape(image)[1] - w) // 2 + return tf.image.crop_to_bounding_box(image, dy, dx, h, w) + + return _crop + + +@Registry.register("preprocess_ops.flip_lr") +@utils.InKeyOutKey() +def get_random_flip_lr(): + """Flips an image horizontally with probability 50%.""" + + def _random_flip_lr_pp(image): + return tf.image.random_flip_left_right(image) + + return _random_flip_lr_pp + + +@Registry.register("preprocess_ops.vgg_value_range") +@utils.InKeyOutKey() +def get_vgg_value_range( + mean=(0.485 * 255, 0.456 * 255, 0.406 * 255), + std=(0.229 * 255, 0.224 * 255, 0.225 * 255), +): + """VGG-style preprocessing, subtracts mean and divides by stddev. + + This preprocessing is very common for ImageNet pre-trained models since VGG, + and to this day the standard for models coming from most PyTorch codes. + + Args: + mean: Tuple of values to be subtracted. Default to widespread VGG values. + std: Tuple of values to be divided by. Default to widespread VGG values. + + Returns: + A function to rescale the values. + """ + mean = tf.constant(mean, tf.float32) + std = tf.constant(std, tf.float32) + + def _vgg_value_range(image): + return (tf.cast(image, tf.float32) - mean) / std + return _vgg_value_range + + +@Registry.register("preprocess_ops.clip_value_range") +@utils.InKeyOutKey() +def get_clip_value_range(): + mean = (0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255) + std = (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255) + + def _clip_value_range(image): + return (tf.cast(image, tf.float32) - mean) / std + return _clip_value_range + + +@Registry.register("preprocess_ops.convert_to_video") +@utils.InKeyOutKey() +def get_convert_to_video(num_frames): + """Converts an image to a video with zero padded frames. + + Args: + num_frames: total number of frames that the video should have. + + Returns: + A function for converting an image to a video. + """ + + def _convert_to_video(image): + return tf.pad( + tf.expand_dims(image, axis=0), + [[0, num_frames - 1], [0, 0], [0, 0], [0, 0]], + ) + + return _convert_to_video diff --git a/big_vision/pp/ops_image_test.py b/big_vision/pp/ops_image_test.py new file mode 100644 index 0000000000000000000000000000000000000000..080fe673cf90f83b405106dd057870ee8e8f76a2 --- /dev/null +++ b/big_vision/pp/ops_image_test.py @@ -0,0 +1,82 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ops_image.""" + +import copy +import io + +import big_vision.pp.ops_image as pp +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf + + +def get_image_data(): + img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32) # Can't ask uint8!? + return {"image": tf.cast(img, tf.uint8)} + + +class PreprocessOpsTest(tf.test.TestCase): + + def tfrun(self, ppfn, data): + # Run once as standalone, as could happen eg in colab. + yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()} + + # And then once again as part of tfdata pipeline. + # You'd be surprised how much these two differ! + tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) + for npdata in tfdata.map(ppfn).as_numpy_iterator(): + yield npdata + + def test_resize(self): + for data in self.tfrun(pp.get_resize([120, 80]), get_image_data()): + self.assertEqual(data["image"].shape, (120, 80, 3)) + + def test_resize_small(self): + for data in self.tfrun(pp.get_resize_small(240), get_image_data()): + self.assertEqual(data["image"].shape, (320, 240, 3)) + + def test_resize_long(self): + for data in self.tfrun(pp.get_resize_long(320), get_image_data()): + self.assertEqual(data["image"].shape, (320, 240, 3)) + + def test_inception_crop(self): + for data in self.tfrun(pp.get_inception_crop(), get_image_data()): + self.assertEqual(data["image"].shape[-1], 3) + + def test_decode_jpeg_and_inception_crop(self): + f = io.BytesIO() + plt.imsave(f, get_image_data()["image"].numpy(), format="jpg") + data = {"image": tf.cast(f.getvalue(), tf.string)} + for data in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), data): + self.assertEqual(data["image"].shape[-1], 3) + + def test_random_crop(self): + for data in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()): + self.assertEqual(data["image"].shape, (120, 80, 3)) + + def test_central_crop(self): + for data in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()): + self.assertEqual(data["image"].shape, (20, 80, 3)) + + def test_random_flip_lr(self): + data_orig = get_image_data() + for data in self.tfrun(pp.get_random_flip_lr(), data_orig): + self.assertTrue( + np.all(data_orig["image"].numpy() == data["image"]) or + np.all(data_orig["image"].numpy() == data["image"][:, ::-1])) + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/pp/ops_text.py b/big_vision/pp/ops_text.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff8bdc3dae1197c7796ec5416c23052d61bef4e --- /dev/null +++ b/big_vision/pp/ops_text.py @@ -0,0 +1,411 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text-centric preprocessing ops. + +All preprocessing ops should return a data processing functors. A data +is represented as a dictionary of (TF) tensors. The functors output a modified +dictionary. + +A commonly used key for the tokenized output is "labels". +""" +import functools +import importlib +import string + +from absl import logging +from big_vision.datasets.imagenet import class_names as imagenet_class_names +from big_vision.pp import ops_general +from big_vision.pp import tokenizer as bv_tok +from big_vision.pp import utils +from big_vision.pp.registry import Registry +import tensorflow as tf + +from tensorflow.io import gfile + +import sentencepiece +SPProcessor = sentencepiece.SentencePieceProcessor + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' +import sentencepiece.sentencepiece_model_pb2 +del os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] +SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto + + +# TODO: b/lbeyer - softly introduce and move to new tokenizer API. + +KNOWN_TOKENIZERS = { + "mc4": # used in multilingual models (mT5, PaLI), vocab_size=250_000 + "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", + "cc_all": # vocab_size=32_000 + "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model", + "c4_en": # vocab_size=32_000 + "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model", + "t5": # same as cc_all, but with 100 extra dummy tokens used by T5 models + "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", + "mt5": # same as mc4, but with 100 extra dummy tokens used by T5 models + "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", +} + + +def create_tokenizer(model="c4_en", add_eos=True, add_bos=False): + """Creates a tokenizer which can be used in tfds.""" + logging.info("Creating tokenizer: %s", model) + with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f: + model = f.read() + + # Lazy import of tensorflow_text so it is an optional dependency for + # the users of this file. + import tensorflow_text + return tensorflow_text.SentencepieceTokenizer( + model=model, add_eos=add_eos, add_bos=add_bos + ) + + +def tokenize(input_text, tokenizer, max_len, *, pad_value, force_eos, + multi_text=False): + """Tokenizes string, and adds `pad_value` if longer than `max_len`.""" + + def pad(tokens): + # Truncate/pad to max_len. + if force_eos: + tokens = tf.cond( + tf.shape(tokens)[0] >= max_len, + lambda: tf.concat( + # For too long, cut them off, but do keep the final EOS token. + [tokens[:max_len - 1], tokens[-1:]], axis=0), + lambda: tf.pad( + tokens, [(0, max_len - tf.shape(tokens)[0])], + constant_values=pad_value), + ) + else: + tokens = tokens[:max_len] + tokens = tf.pad( + tokens, [(0, max_len - tf.shape(tokens)[0])], + constant_values=pad_value) + tokens.set_shape([max_len]) + return tokens + + tokens = tokenizer.tokenize(input_text) + + if multi_text: + tokens = tokens.to_tensor(pad_value) # tf.RaggedTensor to tf.Tensor + tokens = tf.reshape(tokens, [-1, tf.shape(tokens)[-1]]) + tokens = tf.map_fn(pad, tokens) # `map_fn` only maps on axis 0 + + final_shape = tf.concat([tf.shape(input_text), [max_len]], axis=0) + return tf.reshape(tokens, final_shape) + else: + return pad(tokens) + + +@Registry.register("preprocess_ops.tokenize") +@utils.InKeyOutKey(indefault=None, outdefault="labels") +def get_pp_tokenize( + max_len, + eos, + model="c4_en", + lower=True, + sample_if_multi=True, + pad_value="", + add_bos=False +): + """Tokenizes a text. + + Let's assume max_len=3 and id("")=1, id("a")=2, then we have + + 1. `eos="none", pad_value=0`: + - "a" -> [2, 0, 0] + - "aa" -> [2, 2, 0] + - "aaa" -> [2, 2, 2] + + 2. `eos="yes", pad_value=0`: + - "a" -> [2, 1, 0] + - "aa" -> [2, 2, 1] + - "aaa" -> [2, 2, 2] + + This is usually used with generative models that need to learn when to + properly predict a "" (when the sentence is finished) and when to + abstain (when the sentence is truncated). + + 3. `eos="sticky", pad_value=0`: + - "a" -> [2, 1, 0] + - "aa" -> [2, 2, 1] + - "aaa" -> [2, 2, 1] + + 4. `eos="sticky", pad_value=1`: + - "a" -> [2, 1, 1] + - "aa" -> [2, 2, 1] + - "aaa" -> [2, 2, 1] + + This is traditionally used with contrastive models that use the last token + for embeddings, similarly to "cls" tokens in BERT-style models. + + Args: + max_len: maximum length of the tokenized text. + eos: Whether to add an "" (end of sentence) token and whether to keep it + when the sequence is longer than `max_len - 1`. See examples above for + details. Valid values: "none", "yes", "sticky". + model: a path to the pretrained sentencepiece model. + lower: lowercase the text before tokenizing. + sample_if_multi: If there's more than one, randomly pick one if this is + True; otherwise pick all texts and keep the input's batch shape in result. + pad_value: which token to pad the sequence with. If a string (for example + `""`), tokenize it and use its first token. Note that there is no + guarantee to have any padding at the end of the sentence, if the sentence + is longer than `max_len`. + add_bos: adds beginning of sentence symbol. + + Returns: + an op that outputs tokenized text. + """ + + if eos not in ("yes", "none", "sticky"): + raise ValueError(f"Invalid value for eos: '{eos}'.") + + tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos) + + if isinstance(pad_value, str): + pad_value = tokenizer.string_to_id(pad_value) + + def _pp_tokenize(txt): + if sample_if_multi and tf.convert_to_tensor(txt).ndim: + # TODO: I wish this code-path could die. + logging.warning("sample_if_multi is deprecated and will be removed." + "Call `choice` (and maybe `setdefault`) instead.") + txt = ops_general.get_choice(key="t")( + ops_general.get_setdefault("t", "")({"t": txt}))["t"] + + if lower: + txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn( + tf.strings.lower, txt) + + return tokenize( + txt, + tokenizer, + max_len, + pad_value=pad_value, + force_eos=eos == "sticky", + multi_text=not sample_if_multi) + + return _pp_tokenize + + +@Registry.register("preprocess_ops.coco_captions") +def get_coco_captions(outkey="captions"): + """Extracts coco's captions from nested dict.""" + + def _pp_coco_captions(data): + data[outkey] = data["captions"]["text"] + return data + + return _pp_coco_captions + + +@Registry.register("preprocess_ops.clip_i1k_label_names") +@utils.InKeyOutKey(indefault="label", outdefault="labels") +def get_pp_clip_i1k_label_names(): + """Convert i1k label numbers to strings, using CLIP's class names.""" + + def _pp_imagenet_labels(label): + return tf.gather(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, label) + + return _pp_imagenet_labels + + +@Registry.register("preprocess_ops.i21k_label_names") +@utils.InKeyOutKey(indefault="label", outdefault="labels") +def get_pp_i21k_label_names(): + """Converts i21k label ids to strings.""" + + def _pp_imagenet_labels(label): + return tf.gather(imagenet_class_names.IMAGENET21k_CLASS_NAMES, label) + + return _pp_imagenet_labels + + +@Registry.register("preprocess_ops.lower") +@utils.InKeyOutKey(indefault="text", outdefault="text") +def get_lower(): + """Lowercases text feature.""" + + def _pp_lower(text): + return tf.strings.lower(text) + + return _pp_lower + + +@Registry.register("preprocess_ops.strfmt") +def get_strfmt(template, outkey="text"): + """Formats a string template with content form the data dict.""" + + def _template(data): + outputs = [] + parts = string.Formatter().parse(template) + for (literal_text, field_name, format_spec, conversion) in parts: + # For now, we keep it simple and don't support fancy format specs. + # But we can add support to that via py_func as soon as we need it. + assert not format_spec and not conversion + outputs.append(tf.constant(literal_text)) + if field_name: + value = data[field_name] + # Convert any non-strings (numbers, vectors) to a string. + if tf.convert_to_tensor(value).dtype != tf.string: + value = tf.strings.format("{}", value, summarize=-1) + outputs.append(value) + data[outkey] = tf.strings.join(outputs) + return data + + return _template + + +def _add_pieces(model_bytes, extra_pieces): + """Adds extra pieces to sentencpiece model specified by `model_bytes`.""" + + model = SPProcessor() + model.LoadFromSerializedProto(model_bytes) + unk_idx = model.PieceToId("") + assert model.IdToPiece(unk_idx) == "", model.IdToPiece(unk_idx) + + model_proto = SPModelProto.FromString(model_bytes) + idx_to_updated_piece = {} + for piece in extra_pieces: + # The SentencePieceModel proto stores whitespaces as the special + # character '▁'. We perform the conversion here. + piece = piece.replace(" ", "▁") + spiece = model_proto.SentencePiece( + piece=piece, + # We set the highest score to force priority on user defined tokens. + score=0.0, + type=model_proto.SentencePiece().Type.USER_DEFINED, + ) + existing_idx = model.PieceToId(piece) + if (existing_idx != unk_idx) ^ (piece == ""): + idx_to_updated_piece[existing_idx] = spiece + logging.info("Updating token at idx %d: %s", existing_idx, spiece.piece) + else: + model_proto.pieces.append(spiece) + + # Replace duplicated pieces with updated ones. + updated_pieces = [ + idx_to_updated_piece.get(i, piece) + for i, piece in enumerate(model_proto.pieces) + ] + del model_proto.pieces[:] + model_proto.pieces.extend(updated_pieces) + + return model_proto.SerializeToString() + + +def _iterable(x): + if isinstance(x, tf.RaggedTensor): + return True + if getattr(x, "ndim", 0) > 1: # np, jnp + return True + if isinstance(x, (list, tuple)) and not isinstance(x[0], (int, float)): + return True + return False + + +@Registry.register("tokenizers.sp") +class SentencepieceTokenizer(bv_tok.Tokenizer): + """Wraps a `tftext.SentencepieceTokenizer`. + + If you plan to use this tokenizer, please familiarize yourself with the test + cases first. This is likely to save you a lot of troubles down the road, trust + me! + """ + + def __init__(self, model, tokensets=()): + with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f: + model_bytes = f.read() + extras = bv_tok.get_extra_tokens(tokensets) + model_bytes = _add_pieces(model_bytes, extras) + self._tok_sp = SPProcessor() + self._tok_sp.LoadFromSerializedProto(model_bytes) + self.extras = {self._tok_sp.PieceToId(x): x for x in extras} + + def to_int(self, text, *, bos=False, eos=False): + def _single(s): + return ( + ([self.bos_token] if bos else []) + + self._tok_sp.EncodeAsIds(s) + + ([self.eos_token] if eos else []) + ) + if isinstance(text, str): + return _single(text) + return type(text)([_single(s) for s in text]) + + def to_str(self, tokens, *, stop_at_eos=True): + def _single(toks): + toks = [int(t) for t in toks] # We really need this for DecodeIds. + if stop_at_eos: + try: # The SentencePiece strips eos, but does not stop at it, so we do. + toks = toks[:toks.index(self.eos_token)] + except ValueError: # No eos token found, nothing to do. + pass + return self._tok_sp.DecodeIds(toks) + if _iterable(tokens): + return [_single(toks) for toks in tokens] + return _single(tokens) + + def _check_known(self, piece): + if (id_ := self._tok_sp.PieceToId(piece)) == self._tok_sp.unk_id(): + logging.error("Piece '%s' is not known (unk=%s)!", piece, id_) + return id_ + + def to_piece(self, idx): + return self._tok_sp.IdToPiece(int(idx)) + + @property + def pad_token(self): + return self._tok_sp.pad_id() + + @property + def eos_token(self): + return self._tok_sp.eos_id() + + @property + def bos_token(self): + return self._tok_sp.bos_id() + + @property + def vocab_size(self): + return self._tok_sp.GetPieceSize() + + # For the _tf_op variants, we need a lot of wrapping boilerplate. + + def to_int_tf_op(self, text, *, bos=False, eos=False): + text = tf.convert_to_tensor(text) + if text.ndim == 0: + def fn(txt): + s = txt.numpy().decode() + return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32) + return tf.py_function(fn, [text], tf.int32) + else: + def fn(txt): + strings = [s.decode() for s in txt.numpy().tolist()] + toks = self.to_int(strings, bos=bos, eos=eos) + return tf.ragged.constant(toks) + out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32) + return tf.py_function(fn, [text], Tout=out_type) + + def to_str_tf_op(self, tokens, *, stop_at_eos=True): + def single(t): + fn = functools.partial(self.to_str, stop_at_eos=stop_at_eos) + return tf.numpy_function(fn, [t], tf.string, stateful=False) + if _iterable(tokens): + return tf.map_fn(single, tokens, tf.string) + return single(tokens) diff --git a/big_vision/pp/ops_text_test.py b/big_vision/pp/ops_text_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf09d8e35bab0a0e5daec1ead7583d4b02cb82f --- /dev/null +++ b/big_vision/pp/ops_text_test.py @@ -0,0 +1,200 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ops_text.""" + +import copy + +from absl.testing import parameterized +import big_vision.pp.ops_text as pp +from big_vision.pp.registry import Registry +import numpy as np +import tensorflow as tf + + +class PyToTfWrapper: + """Allows to use `to_{int,str}_tf()` via `to_{int,str}()`.""" + + def __init__(self, tok): + self.tok = tok + self.bos_token = tok.bos_token + self.eos_token = tok.eos_token + self.vocab_size = tok.vocab_size + + def to_int(self, text, *, bos=False, eos=False): + ret = self.tok.to_int_tf_op(text, bos=bos, eos=eos) + if isinstance(ret, tf.RaggedTensor): + return [t.numpy().tolist() for t in ret] + return ret.numpy().tolist() + + def to_str(self, tokens, stop_at_eos=True): + ret = self.tok.to_str_tf_op( + tf.ragged.constant(tokens), + stop_at_eos=stop_at_eos, + ) + if ret.ndim == 0: + return ret.numpy().decode() + return [t.numpy().decode() for t in ret] + + +class PpOpsTest(tf.test.TestCase, parameterized.TestCase): + + def tfrun(self, ppfn, data): + # Run once as standalone, as could happen eg in colab. + yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()} + + # And then once again as part of tfdata pipeline. + # You'd be surprised how much these two differ! + tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data)) + for npdata in tfdata.map(ppfn).as_numpy_iterator(): + yield npdata + + def testtok(self): + # https://github.com/google/sentencepiece/blob/master/python/test/test_model.model + return "test_model.model" # Should we just commit it? It's 200kB + + def test_get_pp_clip_i1k_label_names(self): + op = pp.get_pp_clip_i1k_label_names() + labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist() + self.assertAllEqual(labels, ["tench", "goldfish"]) + + def test_get_pp_i21k_label_names(self): + op = pp.get_pp_i21k_label_names() + labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist() + self.assertAllEqual(labels, ["organism", "benthos"]) + + @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"), + (["Decoded Array!"], ["decoded array!"]), + ([b"aA", "bB"], [b"aa", "bb"])) + def test_get_lower(self, inputs, expected_output): + op = pp.get_lower() + out = op({"text": tf.constant(inputs)}) + self.assertAllEqual(out["text"].numpy(), np.array(expected_output)) + + @parameterized.named_parameters( + ("py", False), + ("tf", True), + ) + def test_sentencepiece_tokenizer(self, wrap_tok): + tok = pp.SentencepieceTokenizer(self.testtok()) + if wrap_tok: + tok = PyToTfWrapper(tok) + self.assertEqual(tok.vocab_size, 1000) + bos, eos = tok.bos_token, tok.eos_token + self.assertEqual(bos, 1) + self.assertEqual(eos, 2) + # Note: test model does NOT have a token (similar to e.g. "mistral"). + # `.to_int()` wraps `.to_int_tf_ops` which is thus also tested + self.assertEqual(tok.to_int("blah"), [80, 180, 60]) + self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60]) + self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos]) + self.assertEqual( + tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos] + ) + self.assertEqual( + tok.to_int(["blah", "blah blah"]), + [[80, 180, 60], [80, 180, 60, 80, 180, 60]], + ) + # inverse of above + # `.to_str()` wraps `.to_str_tf_ops` which is thus also tested + self.assertEqual(tok.to_str([80, 180, 60]), "blah") + self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah") + self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah") + self.assertEqual( + tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]), + ["blah", "blah blah"], + ) + + def test_sentencepiece_tokenizer_tf_op_ndarray_input(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + bos, eos = tok.bos_token, tok.eos_token + arr = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32) + self.assertEqual(tok.to_str_tf_op(arr).numpy().tolist(), [b"blah"] * 2) + + def test_sentencepiece_tokenizer_tokensets(self): + tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"]) + self.assertEqual(tok.vocab_size, 2024) + self.assertEqual( + tok.to_int("blah"), [80, 180, 60, 1000, 2023] + ) + + def test_sentencepiece_stop_at_eos(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah") + eos = tok.eos_token + self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah") + self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b") + self.assertEqual( + tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True), + ["b", "bla"] + ) + + def test_sentencepiece_extra_tokens(self): + tok = pp.SentencepieceTokenizer(self.testtok()) + self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah") + tok = pp.SentencepieceTokenizer( + self.testtok(), tokensets=["sp_extra_tokens"] + ) + self.assertEqual(tok.vocab_size, 1001) # Also added the token. + self.assertEqual( + tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), " blah" + ) + + def test_strfmt(self): + data = { + "int": tf.constant(42, tf.uint8), + "float": tf.constant(3.14, tf.float32), + "vec": tf.range(3), + "empty_str": tf.constant(""), + "regex_problem1": tf.constant(r"no \replace pattern"), + "regex_problem2": tf.constant(r"yes \1 pattern"), + } + for out in self.tfrun(pp.get_strfmt("Nothing"), data): + self.assertEqual(out["text"], b"Nothing") + for out in self.tfrun(pp.get_strfmt("{int}"), data): + self.assertEqual(out["text"], b"42") + for out in self.tfrun(pp.get_strfmt("A{int}"), data): + self.assertEqual(out["text"], b"A42") + for out in self.tfrun(pp.get_strfmt("{int}A"), data): + self.assertEqual(out["text"], b"42A") + for out in self.tfrun(pp.get_strfmt("{int}{int}"), data): + self.assertEqual(out["text"], b"4242") + for out in self.tfrun(pp.get_strfmt("A{int}A{int}A"), data): + self.assertEqual(out["text"], b"A42A42A") + for out in self.tfrun(pp.get_strfmt("A{float}A"), data): + self.assertEqual(out["text"], b"A3.14A") + for out in self.tfrun(pp.get_strfmt("A{float}A{int}"), data): + self.assertEqual(out["text"], b"A3.14A42") + for out in self.tfrun(pp.get_strfmt("A{vec}A"), data): + self.assertEqual(out["text"], b"A[0 1 2]A") + for out in self.tfrun(pp.get_strfmt("A{empty_str}A"), data): + self.assertEqual(out["text"], b"AA") + for out in self.tfrun(pp.get_strfmt("{empty_str}"), data): + self.assertEqual(out["text"], b"") + for out in self.tfrun(pp.get_strfmt("A{regex_problem1}A"), data): + self.assertEqual(out["text"], br"Ano \replace patternA") + for out in self.tfrun(pp.get_strfmt("A{regex_problem2}A"), data): + self.assertEqual(out["text"], br"Ayes \1 patternA") + + +@Registry.register("tokensets.sp_extra_tokens") +def _get_sp_extra_tokens(): + # For sentencepiece, adding these tokens will make them visible when decoding. + # If a token is not found (e.g. "" is not found in "mistral"), then it is + # added to the vocabulary, increasing the vocab_size accordingly. + return ["", "", ""] + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/pp/registry.py b/big_vision/pp/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c7d996d756be16ba68a5fcb143f23129e1249d --- /dev/null +++ b/big_vision/pp/registry.py @@ -0,0 +1,163 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Global Registry for big_vision pp ops. + +Author: Joan Puigcerver (jpuigcerver@) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast +import contextlib +import functools + + +def parse_name(string_to_parse): + """Parses input to the registry's lookup function. + + Args: + string_to_parse: can be either an arbitrary name or function call + (optionally with positional and keyword arguments). + e.g. "multiclass", "resnet50_v2(filters_factor=8)". + + Returns: + A tuple of input name, argument tuple and a keyword argument dictionary. + Examples: + "multiclass" -> ("multiclass", (), {}) + "resnet50_v2(9, filters_factor=4)" -> + ("resnet50_v2", (9,), {"filters_factor": 4}) + + Author: Joan Puigcerver (jpuigcerver@) + """ + expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error + if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): + raise ValueError( + "The given string should be a name or a call, but a {} was parsed from " + "the string {!r}".format(type(expr), string_to_parse)) + + # Notes: + # name="some_name" -> type(expr) = ast.Name + # name="module.some_name" -> type(expr) = ast.Attribute + # name="some_name()" -> type(expr) = ast.Call + # name="module.some_name()" -> type(expr) = ast.Call + + if isinstance(expr, ast.Name): + return string_to_parse, (), {} + elif isinstance(expr, ast.Attribute): + return string_to_parse, (), {} + + def _get_func_name(expr): + if isinstance(expr, ast.Attribute): + return _get_func_name(expr.value) + "." + expr.attr + elif isinstance(expr, ast.Name): + return expr.id + else: + raise ValueError( + "Type {!r} is not supported in a function name, the string to parse " + "was {!r}".format(type(expr), string_to_parse)) + + def _get_func_args_and_kwargs(call): + args = tuple([ast.literal_eval(arg) for arg in call.args]) + kwargs = { + kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords + } + return args, kwargs + + func_name = _get_func_name(expr.func) + func_args, func_kwargs = _get_func_args_and_kwargs(expr) + + return func_name, func_args, func_kwargs + + +class Registry(object): + """Implements global Registry. + + Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@) + """ + + _GLOBAL_REGISTRY = {} + + @staticmethod + def global_registry(): + return Registry._GLOBAL_REGISTRY + + @staticmethod + def register(name, replace=False): + """Creates a function that registers its input.""" + + def _register(item): + if name in Registry.global_registry() and not replace: + raise KeyError("The name {!r} was already registered.".format(name)) + + Registry.global_registry()[name] = item + return item + + return _register + + @staticmethod + def lookup(lookup_string, kwargs_extra=None): + """Lookup a name in the registry.""" + + try: + name, args, kwargs = parse_name(lookup_string) + except ValueError as e: + raise ValueError(f"Error parsing:\n{lookup_string}") from e + if kwargs_extra: + kwargs.update(kwargs_extra) + item = Registry.global_registry()[name] + return functools.partial(item, *args, **kwargs) + + @staticmethod + def knows(lookup_string): + try: + name, _, _ = parse_name(lookup_string) + except ValueError as e: + raise ValueError(f"Error parsing:\n{lookup_string}") from e + return name in Registry.global_registry() + + +@contextlib.contextmanager +def temporary_ops(**kw): + """Registers specified pp ops for use in a `with` block. + + Example use: + + with pp_registry.remporary_ops( + pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}): + pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)") + features = pp(features) + + Args: + **kw: Names are preprocess string function names to be used to specify the + preprocess function. Values are functions that can be called with params + (e.g. the `alpha` param in above example) and return functions to be used + to transform features. + + Yields: + A context manager to be used in a `with` statement. + """ + reg = Registry.global_registry() + kw = {f"preprocess_ops.{k}": v for k, v in kw.items()} + for k in kw: + assert k not in reg + for k, v in kw.items(): + reg[k] = v + try: + yield + finally: + for k in kw: + del reg[k] diff --git a/big_vision/pp/registry_test.py b/big_vision/pp/registry_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2296e7de91ce0495bade59e8e65417384507e58e --- /dev/null +++ b/big_vision/pp/registry_test.py @@ -0,0 +1,128 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for registry.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from unittest import mock + +from absl.testing import absltest +from big_vision.pp import registry + + +class RegistryTest(absltest.TestCase): + + def setUp(self): + super(RegistryTest, self).setUp() + # Mock global registry in each test to keep them isolated and allow for + # concurrent tests. + self.addCleanup(mock.patch.stopall) + self.global_registry = dict() + self.mocked_method = mock.patch.object( + registry.Registry, "global_registry", + return_value=self.global_registry).start() + + def test_parse_name(self): + name, args, kwargs = registry.parse_name("f") + self.assertEqual(name, "f") + self.assertEqual(args, ()) + self.assertEqual(kwargs, {}) + + name, args, kwargs = registry.parse_name("f()") + self.assertEqual(name, "f") + self.assertEqual(args, ()) + self.assertEqual(kwargs, {}) + + name, args, kwargs = registry.parse_name("func(a=0,b=1,c='s')") + self.assertEqual(name, "func") + self.assertEqual(args, ()) + self.assertEqual(kwargs, {"a": 0, "b": 1, "c": "s"}) + + name, args, kwargs = registry.parse_name("func(1,'foo',3)") + self.assertEqual(name, "func") + self.assertEqual(args, (1, "foo", 3)) + self.assertEqual(kwargs, {}) + + name, args, kwargs = registry.parse_name("func(1,'2',a=3,foo='bar')") + self.assertEqual(name, "func") + self.assertEqual(args, (1, "2")) + self.assertEqual(kwargs, {"a": 3, "foo": "bar"}) + + name, args, kwargs = registry.parse_name("foo.bar.func(a=0,b=(1),c='s')") + self.assertEqual(name, "foo.bar.func") + self.assertEqual(kwargs, dict(a=0, b=1, c="s")) + + with self.assertRaises(SyntaxError): + registry.parse_name("func(0") + with self.assertRaises(SyntaxError): + registry.parse_name("func(a=0,,b=0)") + with self.assertRaises(SyntaxError): + registry.parse_name("func(a=0,b==1,c='s')") + with self.assertRaises(ValueError): + registry.parse_name("func(a=0,b=undefined_name,c='s')") + + def test_register(self): + # pylint: disable=unused-variable + @registry.Registry.register("func1") + def func1(): + pass + + self.assertLen(registry.Registry.global_registry(), 1) + + def test_lookup_function(self): + + @registry.Registry.register("func1") + def func1(arg1, arg2, arg3): # pylint: disable=unused-variable + return arg1, arg2, arg3 + + self.assertTrue(callable(registry.Registry.lookup("func1"))) + self.assertEqual(registry.Registry.lookup("func1")(1, 2, 3), (1, 2, 3)) + self.assertEqual( + registry.Registry.lookup("func1(arg3=9)")(1, 2), (1, 2, 9)) + self.assertEqual( + registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg3=3), (99, 9, 3)) + self.assertEqual( + registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg1=1, arg3=3), + (1, 9, 3)) + + self.assertEqual( + registry.Registry.lookup("func1(1)")(1, 2), (1, 1, 2)) + self.assertEqual( + registry.Registry.lookup("func1(1)")(arg3=3, arg2=2), (1, 2, 3)) + self.assertEqual( + registry.Registry.lookup("func1(1, 2)")(3), (1, 2, 3)) + self.assertEqual( + registry.Registry.lookup("func1(1, 2)")(arg3=3), (1, 2, 3)) + self.assertEqual( + registry.Registry.lookup("func1(1, arg2=2)")(arg3=3), (1, 2, 3)) + self.assertEqual( + registry.Registry.lookup("func1(1, arg3=2)")(arg2=3), (1, 3, 2)) + self.assertEqual( + registry.Registry.lookup("func1(1, arg3=2)")(3), (1, 3, 2)) + + with self.assertRaises(TypeError): + registry.Registry.lookup("func1(1, arg2=2)")(3) + with self.assertRaises(TypeError): + registry.Registry.lookup("func1(1, arg3=3)")(arg3=3) + with self.assertRaises(TypeError): + registry.Registry.lookup("func1(1, arg3=3)")(arg1=3) + with self.assertRaises(SyntaxError): + registry.Registry.lookup("func1(arg1=1, 3)")(arg2=3) + + +if __name__ == "__main__": + absltest.main() diff --git a/big_vision/pp/tokenizer.py b/big_vision/pp/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..681494e436aacd48d5d720e07d2df1a80c704eb2 --- /dev/null +++ b/big_vision/pp/tokenizer.py @@ -0,0 +1,103 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The tokenizer API for big_vision, and central registration place.""" +import functools +import importlib +from typing import Protocol + +from absl import logging +from big_vision.pp import registry +import big_vision.utils as u +import numpy as np + + +class Tokenizer(Protocol): + """Just to unify on the API as we now have mmany different ones.""" + + def to_int(self, text, *, bos=False, eos=False): + """Tokenizes `text` into a list of integer tokens. + + Args: + text: can be a single string, or a list of strings. + bos: Whether a beginning-of-sentence token should be prepended. + eos: Whether an end-of-sentence token should be appended. + + Returns: + List or list-of-list of tokens. + """ + + def to_int_tf_op(self, text, *, bos=False, eos=False): + """Same as `to_int()`, but as TF ops to be used in pp.""" + + def to_str(self, tokens, *, stop_at_eos=True): + """Inverse of `to_int()`. + + Args: + tokens: list of tokens, or list of lists of tokens. + stop_at_eos: remove everything that may come after the first EOS. + + Returns: + A string (if `tokens` is a list of tokens), or a list of strings. + Note that most tokenizers strip select few control tokens like + eos/bos/pad/unk from the output string. + """ + + def to_str_tf_op(self, tokens, *, stop_at_eos=True): + """Same as `to_str()`, but as TF ops to be used in pp.""" + + @property + def pad_token(self): + """Token id of padding token.""" + + @property + def eos_token(self): + """Token id of end-of-sentence token.""" + + @property + def bos_token(self): + """Token id of beginning-of-sentence token.""" + + @property + def vocab_size(self): + """Returns the size of the vocabulary.""" + + +@functools.cache +def get_tokenizer(name): + with u.chrono.log_timing(f"z/secs/tokenizer/{name}"): + if not registry.Registry.knows(f"tokenizers.{name}"): + raw_name, *_ = registry.parse_name(name) + logging.info("Tokenizer %s not registered, " + "trying import big_vision.pp.%s", name, raw_name) + importlib.import_module(f"big_vision.pp.{raw_name}") + + return registry.Registry.lookup(f"tokenizers.{name}")() + + +def get_extra_tokens(tokensets): + extra_tokens = [] + for tokenset in tokensets: + extra_tokens.extend(registry.Registry.lookup(f"tokensets.{tokenset}")()) + return list(np.unique(extra_tokens)) # Preserves order. Dups make no sense. + + +@registry.Registry.register("tokensets.loc") +def _get_loc1024(n=1024): + return [f"" for i in range(n)] + + +@registry.Registry.register("tokensets.seg") +def _get_seg(n=128): + return [f"" for i in range(n)] diff --git a/big_vision/pp/utils.py b/big_vision/pp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee834560246549c71f0a6d9785694fd1507ca9b --- /dev/null +++ b/big_vision/pp/utils.py @@ -0,0 +1,53 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preprocessing utils.""" + +from collections import abc + + +def maybe_repeat(arg, n_reps): + if not isinstance(arg, abc.Sequence) or isinstance(arg, str): + arg = (arg,) * n_reps + return arg + + +class InKeyOutKey(object): + """Decorator for preprocessing ops, which adds `inkey` and `outkey` arguments. + + Note: Only supports single-input single-output ops. + """ + + def __init__(self, indefault="image", outdefault="image", with_data=False): + self.indefault = indefault + self.outdefault = outdefault + self.with_data = with_data + + def __call__(self, orig_get_pp_fn): + + def get_ikok_pp_fn(*args, key=None, + inkey=self.indefault, outkey=self.outdefault, **kw): + + orig_pp_fn = orig_get_pp_fn(*args, **kw) + def _ikok_pp_fn(data): + # Optionally allow the function to get the full data dict as aux input. + if self.with_data: + data[key or outkey] = orig_pp_fn(data[key or inkey], data=data) + else: + data[key or outkey] = orig_pp_fn(data[key or inkey]) + return data + + return _ikok_pp_fn + + return get_ikok_pp_fn diff --git a/big_vision/pp/utils_test.py b/big_vision/pp/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..beec18cef62a9638ed143229d7aedc5e218a70b6 --- /dev/null +++ b/big_vision/pp/utils_test.py @@ -0,0 +1,53 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for preprocessing utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from big_vision.pp import utils +import tensorflow.compat.v1 as tf + + +class UtilsTest(tf.test.TestCase): + + def test_maybe_repeat(self): + self.assertEqual((1, 1, 1), utils.maybe_repeat(1, 3)) + self.assertEqual((1, 2), utils.maybe_repeat((1, 2), 2)) + self.assertEqual([1, 2], utils.maybe_repeat([1, 2], 2)) + + def test_inkeyoutkey(self): + @utils.InKeyOutKey() + def get_pp_fn(shift, scale=0): + def _pp_fn(x): + return scale * x + shift + return _pp_fn + + data = {"k_in": 2, "other": 3} + ppfn = get_pp_fn(1, 2, inkey="k_in", outkey="k_out") # pylint: disable=unexpected-keyword-arg + self.assertEqual({"k_in": 2, "k_out": 5, "other": 3}, ppfn(data)) + + data = {"k": 6, "other": 3} + ppfn = get_pp_fn(1, inkey="k", outkey="k") # pylint: disable=unexpected-keyword-arg + self.assertEqual({"k": 1, "other": 3}, ppfn(data)) + + data = {"other": 6, "image": 3} + ppfn = get_pp_fn(5, 2) + self.assertEqual({"other": 6, "image": 11}, ppfn(data)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/big_vision/requirements.txt b/big_vision/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ae71db5ac26e2746864c37959dfd28cb9fe70bf --- /dev/null +++ b/big_vision/requirements.txt @@ -0,0 +1,19 @@ +numpy>=1.26 +absl-py +git+https://github.com/google/CommonLoopUtils +distrax +editdistance +einops +flax +optax +git+https://github.com/google/flaxformer +git+https://github.com/akolesnikoff/panopticapi.git@mute +overrides +protobuf +sentencepiece +tensorflow-cpu +tfds-nightly +tensorflow-text +tensorflow-gan +psutil +pycocoevalcap diff --git a/big_vision/run_tpu.sh b/big_vision/run_tpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c3da2e44e7d2829a00188f5e6177ea9d6e3ba4d --- /dev/null +++ b/big_vision/run_tpu.sh @@ -0,0 +1,35 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash + +if [ ! -d "bv_venv" ] +then + sudo apt-get update + sudo apt install -y python3-venv + python3 -m venv bv_venv + . bv_venv/bin/activate + + pip install -U pip # Yes, really needed. + # NOTE: doesn't work when in requirements.txt -> cyclic dep + pip install "jax[tpu]>=0.4.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install -r big_vision/requirements.txt +else + . bv_venv/bin/activate +fi + +if [ $# -ne 0 ] +then + env TFDS_DATA_DIR=$TFDS_DATA_DIR BV_JAX_INIT=1 python3 -m "$@" +fi diff --git a/big_vision/sharding.py b/big_vision/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..be76cb3a1f6b8bc0e494515bac2a54528a53494c --- /dev/null +++ b/big_vision/sharding.py @@ -0,0 +1,197 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Big vision sharding utilities.""" + +from absl import logging + +from big_vision.pp.registry import Registry +import big_vision.utils as u +import flax.linen as nn +import jax +import numpy as np + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def _replicated(mesh): + return NamedSharding(mesh, P()) + + +def _shard_along_axis(mesh, i, axis_name): + return NamedSharding(mesh, P(*((None,) * i + (axis_name,)))) + + +def infer_sharding(params, strategy, mesh): + """Infers `params` sharding based on strategy. + + Args: + params: a pytree of arrays. + strategy: sharding strategy. + mesh: jax device mesh. + + Returns: + A pytree with shardings, that has the same shape as the `tree` argument. + """ + patterns, tactics = zip(*strategy) + + x_with_names, tree_def = u.tree_flatten_with_names(params) + names = tree_def.unflatten(list(zip(*x_with_names))[0]) + + # Follows big_vision conventions: each variable is matched at most once, + # early patterns get matching priority. + mask_trees = u.make_mask_trees(params, patterns) + + specs = jax.tree.map(lambda x: (None,) * x.ndim, params) + + for mask_tree, tactic in zip(mask_trees, tactics): + for op_str in tactic.split("|"): + op = Registry.lookup(f"shardings.{op_str}")() + specs = jax.tree.map( + lambda x, n, match, spec, op=op: op(spec, mesh, n, x) + if match else spec, + params, names, mask_tree, specs, + is_leaf=lambda v: isinstance(v, nn.Partitioned)) + + # Two-level tree_map to prevent it from doing traversal inside the spec. + specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs) + return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs) + + +# Sharding rules +# +# Each rule needs to be added to the registry, can accept custom args, and +# returns a function that updates the current spec. The arguments are: +# 1. Variable name +# 2. Variable itself (or placeholder with .shape and .dtype properties) +# 3. The current sharing spec. + + +@Registry.register("shardings.replicate") +def replicate(): + """Full replication sharding rule. + + Note full replication is deafult, so this can be skipped and useful to + explicitly state in the config that certrain parameters are replicated. + TODO: can be generalized to support replication over a sub-mesh. + + Returns: + A function that updates the sharding spec. + """ + def _update_spec(cur_spec, mesh, name, x): + del x, mesh + if not all(axis is None for axis in cur_spec): + raise ValueError(f"Inconsistent sharding instructions: " + f"parameter {name} has spec {cur_spec}, " + f"so it can't be fully replicated.") + return cur_spec + return _update_spec + + +@Registry.register("shardings.fsdp") +def fsdp(axis, min_size_to_shard_mb=4): + """FSDP sharding rule. + + Shards the largest dimension that is not sharded already and is divisible + by the total device count. + + Args: + axis: mesh axis name for FSDP, or a collection of names. + min_size_to_shard_mb: minimal tensor size to bother with sharding. + + Returns: + A function that updates the sharding spec. + """ + axis = axis if isinstance(axis, str) else tuple(axis) + axis_tuple = axis if isinstance(axis, tuple) else (axis,) + def _update_spec(cur_spec, mesh, name, x): + shape = x.shape + axis_size = np.prod([mesh.shape[a] for a in axis_tuple]) + + if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20): + return cur_spec + + # Partition along largest axis that is divisible and not taken. + idx = np.argsort(shape)[::-1] + for i in idx: + if shape[i] % axis_size == 0: + if cur_spec[i] is None: + return cur_spec[:i] + (axis,) + cur_spec[i+1:] + + logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as all " + "its dimensions are not divisible by the requested axis: " + "%s:%i, or already occupied by other sharding rules: %s", + name, shape, axis, axis_size, cur_spec) + return cur_spec + return _update_spec + + +@Registry.register("shardings.logical_partitioning") +def logical_partitioning(): + """Manual sharding based on Flax's logical partitioning annotations. + + Uses logical sharding annotations added in model code with + `nn.with_logical_partitioning`. Respects logical to mesh name mapping rules + (typically defined in the dynamic context using + `with nn.logical_axis_rules(rules): ...`). + + Returns: + A function that outputs the sharding spec of `nn.LogicallyPartitioned` boxed + specs. + """ + def _update_spec(cur_spec, mesh, name, x): + del x, name, mesh + if isinstance(cur_spec, nn.LogicallyPartitioned): + return nn.logical_to_mesh_axes(cur_spec.names) + return cur_spec + return _update_spec + + +@Registry.register("shardings.shard_dim") +def shard_dim(axis, dim, ignore_ndim_error=False): + """Shards the given dimension along the given axis. + + Args: + axis: mesh axis name for sharding. + dim: dimension to shard (can be negative). + ignore_ndim_error: if True, a warning error is logged instead of raising an + exception when the given dimension is not compatible with the number of + dimensions of the array. + + Returns: + A function that updates the sharding spec. + """ + def _update_spec(cur_spec, mesh, name, x): + del mesh, x + if np.abs(dim) >= len(cur_spec): + msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}" + if ignore_ndim_error: + logging.warning(msg) + return cur_spec + else: + raise ValueError(msg) + pos_dim = dim + if pos_dim < 0: + pos_dim += len(cur_spec) + if cur_spec[pos_dim] is not None: + raise ValueError( + f"Already sharded: shard_dim({axis}, {dim}):" + f" name={name} cur_spec={cur_spec}" + ) + new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :] + return new_spec + + return _update_spec diff --git a/big_vision/tools/download_tfds_datasets.py b/big_vision/tools/download_tfds_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b64c33d51a7cb8df063d15e828e30b07007ff6b0 --- /dev/null +++ b/big_vision/tools/download_tfds_datasets.py @@ -0,0 +1,44 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download and prepare TFDS datasets for the big_vision codebase. + +This python script covers cifar10, cifar100, oxford_iiit_pet +and oxford_flowers10. + +If you want to integrate other public or custom datasets, please follow: +https://www.tensorflow.org/datasets/catalog/overview +""" + +from absl import app +import tensorflow_datasets as tfds + + +def main(argv): + if len(argv) > 1 and "download_tfds_datasets.py" in argv[0]: + datasets = argv[1:] + else: + datasets = [ + "cifar10", + "cifar100", + "oxford_iiit_pet", + "oxford_flowers102", + "imagenet_v2", + ] + for d in datasets: + tfds.load(name=d, download=True) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/tools/eval_only.py b/big_vision/tools/eval_only.py new file mode 100644 index 0000000000000000000000000000000000000000..abdde4a6c0aa656a2e8ec76ce645982a2a6723b3 --- /dev/null +++ b/big_vision/tools/eval_only.py @@ -0,0 +1,146 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script that loads a model and only runs evaluators.""" + +from functools import partial +import importlib + +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.jax_utils as flax_utils +import jax +import jax.numpy as jnp +from ml_collections import config_flags +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() + + +def main(argv): + del argv + + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info("Workdir: %s", workdir) + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image"]): + importlib.import_module(f"big_vision.pp.{m}") + + # These functions do more stuff internally, for OSS release we mock them by + # trivial alternatives in order to minize disruptions in the code. + xid, wid = -1, -1 + def write_note(note): + if jax.process_index() == 0: + logging.info("NOTE: %s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + u.chrono.inform(measure=mw.measure, write_note=write_note) + + write_note(f"Initializing {config.model_name} model...") + assert config.get("model.reinit") is None, ( + "I don't think you want any part of the model to be re-initialized.") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model_kw = dict(config.get("model", {})) + if "num_classes" in config: # Make it work for regular + image_text. + model_kw["num_classes"] = config.num_classes + model = model_mod.Model(**model_kw) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @partial(jax.jit, backend="cpu") + def init(rng): + input_shapes = config.get("init_shapes", [(1, 224, 224, 3)]) + input_types = config.get("init_types", [jnp.float32] * len(input_shapes)) + dummy_inputs = [jnp.zeros(s, t) for s, t in zip(input_shapes, input_types)] + things = flax.core.unfreeze(model.init(rng, *dummy_inputs)) + return things.get("params", {}) + + with u.chrono.log_timing("z/secs/init"): + params_cpu = init(jax.random.PRNGKey(42)) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview(params_cpu, msg="init params") + num_params = sum(p.size for p in jax.tree.leaves(params_cpu)) + mw.measure("num_params", num_params) + + # The use-case for not loading an init is testing and debugging. + if config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.get("model"), + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview(params_cpu, msg="loaded params") + + write_note("Replicating...") + params_repl = flax_utils.replicate(params_cpu) + + def predict_fn(params, *a, **kw): + return model.apply({"params": params}, *a, **kw) + + evaluators = eval_common.from_config( + config, {"predict": predict_fn, "model": model}, + lambda s: write_note(f"Initializing evaluator: {s}..."), + lambda key, cfg: 1, # Ignore log_steps, always run. + ) + + # Allow running for multiple steps can be useful for couple cases: + # 1. non-deterministic evaluators + # 2. warmup when timing evaluators (eg compile cache etc). + for s in range(config.get("eval_repeats", 1)): + mw.step_start(s) + for (name, evaluator, _, prefix) in evaluators: + write_note(f"{name} evaluation step {s}...") + with u.profile(name, noop=name in config.get("no_profile", [])): + with u.chrono.log_timing(f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.sync() # sync barrier to get correct measurements + u.chrono.flush_timings() + mw.step_end() + + write_note("Done!") + mw.close() + + # Make sure all hosts stay up until the end of main. + u.sync() + + if workdir and flags.FLAGS.cleanup and jax.process_index() == 0: + gfile.rmtree(workdir) + try: # Only need this on the last work-unit, if already empty. + gfile.remove(os.path.join(workdir, "..")) + except tf.errors.OpError: + pass + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/tools/lit_demo/README.md b/big_vision/tools/lit_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/tools/lit_demo/build.js b/big_vision/tools/lit_demo/build.js new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/tools/lit_demo/package.json b/big_vision/tools/lit_demo/package.json new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/big_vision/train.py b/big_vision/train.py new file mode 100644 index 0000000000000000000000000000000000000000..51f49cc75c59f27e6393c73b400420e2c89da9e0 --- /dev/null +++ b/big_vision/train.py @@ -0,0 +1,517 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training loop example. + +This is a basic variant of a training loop, good starting point for fancy ones. +""" +# pylint: disable=consider-using-from-import +# pylint: disable=logging-fstring-interpolation + +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.utils as u +from clu import parameter_overview +import flax.linen as nn +import jax +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. +jax.config.update("jax_threefry_partitionable", True) + + +NamedSharding = jax.sharding.NamedSharding +P = jax.sharding.PartitionSpec + + +def main(argv): + del argv + + # This is needed on multihost systems, but crashes on non-TPU single-host. + if os.environ.get("BV_JAX_INIT"): + jax.distributed.initialize() + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + config = flags.FLAGS.config + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + logging.info(f"The config:\n{config}") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. + pool = multiprocessing.pool.ThreadPool(1) + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. + xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + + # Allow for things like timings as early as possible! + u.chrono.inform(measure=mw.measure, write_note=write_note) + +################################################################################ +# # +# Set up Mesh # +# # +################################################################################ + + # We rely on jax mesh_utils to organize devices, such that communication + # speed is the fastest for the last dimension, second fastest for the + # penultimate dimension, etc. + config_mesh = config.get("mesh", [("data", jax.device_count())]) + + # Sharding rules with default + sharding_rules = config.get("sharding_rules", [("act_batch", "data")]) + + write_note("Creating device mesh...") + mesh = u.create_device_mesh( + config_mesh, + allow_split_physical_axes=config.get("mesh_allow_split_physical_axes", + False)) + repl_sharding = jax.sharding.NamedSharding(mesh, P()) + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. The + # order presribed by the `devices_flat` variable should be used throughout + # the program. + devices_flat = mesh.devices.flatten() + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note("Creating model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model( + num_classes=config.num_classes, **config.get("model", {})) + + def init(rng): + batch = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype), + train_ds.element_spec) + params = model.init(rng, batch["image"])["params"] + + # Set bias in the head to a low value, such that loss is small initially. + if "init_head_bias" in config: + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) + + return params + + # This seed makes the Jax part of things (like model init) deterministic. + # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. + rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + write_note("Inferring parameter shapes...") + rng, rng_init = jax.random.split(rng) + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, nn.unbox(params_shape), sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + if jax.process_index() == 0: + num_params = sum(np.prod(p.shape) for p in jax.tree.leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + write_note("Inferring shardings...") + train_state_shape = {"params": params_shape, "opt": opt_shape} + + strategy = config.get("sharding_strategy", [(".*", "replicate")]) + with nn.logical_axis_rules(sharding_rules): + train_state_sharding = bv_sharding.infer_sharding( + train_state_shape, strategy=strategy, mesh=mesh) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init) + opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + # From here on, we have no need for Flax AxisMetadata (such as partitioning). + train_state = nn.unbox({"params": params, "opt": opt}) + del params, opt # Delete to avoid memory leak or accidental reuse. + + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, rng, batch): + """Update step.""" + + images, labels = batch["image"], batch["labels"] + + step_count = bv_optax.get_count(train_state["opt"], jittable=True) + rng = jax.random.fold_in(rng, step_count) + + if config.get("mixup") and config.mixup.p: + # The shard_map below makes mixup run on every device independently and + # thus avoids unnecessary communication. + sharded_mixup_fn = shard_map( + u.get_mixup(rng, config.mixup.p), + mesh=jax.sharding.Mesh(devices_flat, ("data",)), + in_specs=P("data"), out_specs=(P(), P("data"), P("data"))) + rng, (images, labels), _ = sharded_mixup_fn(images, labels) + + # Get device-specific loss rng. + rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + logits, _ = model.apply( + {"params": params}, images, + train=True, rngs={"dropout": rng_model}) + return getattr(u, config.get("loss", "sigmoid_xent"))( + logits=logits, labels=labels) + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree.leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree.leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree.map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree.map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + # TODO: when updating the `load` API soon, do pass and request the + # full `train_state` from it. Examples where useful: VQVAE, BN. + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. Later `jit` will prune out things that are not needed. + def eval_logits_fn(train_state, batch): + logits, out = model.apply({"params": train_state["params"]}, batch["image"]) + return logits, out + + def eval_loss_fn(train_state, batch): + logits, _ = model.apply({"params": train_state["params"]}, batch["image"]) + loss_fn = getattr(u, config.get("loss", "sigmoid_xent")) + return { + "loss": loss_fn(logits=logits, labels=batch["labels"], reduction=False) + } + + eval_fns = { + "predict": eval_logits_fn, + "loss": eval_loss_fn, + } + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, eval_fns, + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules(sharding_rules): + train_state, measurements = update_fn(train_state, rng_loop, batch) + + # On the first host, let's always profile a handful of early steps. + if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + for k in ("training_loss", "l2_grads", "l2_updates", "l2_params"): + if not np.isfinite(measurements.get(k, 0.0)): + raise RuntimeError(f"{k} became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree.map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules(sharding_rules): + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/utils.py b/big_vision/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a954300f9ac7e4ccce27c38eb1367d3233451532 --- /dev/null +++ b/big_vision/utils.py @@ -0,0 +1,1478 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils very specific to this project, not generic.""" + +import collections +import contextlib +import dataclasses +import functools +import io +import json +import multiprocessing +import multiprocessing.pool +import os +import re +import sys +import time +from typing import Mapping + +from absl import flags +from absl import logging +from big_vision.pp import registry as pp_registry +import einops +import flax +import flax.jax_utils as flax_utils +import jax +from jax.experimental import mesh_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +import ml_collections as mlc +import numpy as np + +import tensorflow.io.gfile as gfile # pylint: disable=consider-using-from-import + + +Registry = pp_registry.Registry + + +# pylint: disable=logging-fstring-interpolation + + +def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=()): + """Wraps a function with code that pads, shards, then un-shards, un-pads. + + Args: + wrapped: the function to be wrapped. Signature is `params, *args, *kwargs`. + static_argnums: indices of arguments to `wrapped` that should _not_ be + padded and sharded, but instead be forwarded as-is. The default is (0,) + because by far the most common use-case is to pass `params` first. + static_argnames: names of kwargs to `wrapped` that should _not_ be padded + and sharded, but instead be forwarded as-is. + + Returns: + A new function that pads and shards its arguments before passing them to + the wrapped function, and un-shards and un-pads the returned pytree. + + This is useful for calling a pmap'ed function with inputs that aren't + divisible by the number of devices. A typical use is: + @pad_shard_unpad + @jax.pmap + def forward(params, x): ... + + Notes: + The padding is done in host-memory before being passed to the function, and + the values returned by the function are transferred back to host memory. + + The returned function is augmented with a new keyword-only argument + `min_device_batch` that, if specified, forces padding inputs to at least + this size per device. This can be useful to avoid recompiles for the last + batch and reduce memory fragmentation. + """ + + def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw): + d = jax.local_device_count() # d = devices, b = batch + + # Find the batch-sizes of all non-static arguments. + def get_bs(x): + batch_sizes = jax.tree.map(lambda y: y.shape[0], x) + return jax.tree.flatten(batch_sizes)[0] + + bs_a = [get_bs(a) for i, a in enumerate(args) if i not in static_argnums] + bs_kw = [get_bs(v) for k, v in kw.items() if k not in static_argnames] + bs = set([n for b in (bs_a + bs_kw) for n in b]) + assert len(bs) == 1, f"Inconsistent batch-sizes: {bs}" + b = bs.pop() + + def pad(x): + _, *shape = x.shape + db, rest = divmod(b, d) + if rest: + x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0) + db += 1 + if min_device_batch and db < min_device_batch: + x = np.concatenate( + [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)]) + db = min_device_batch + return x.reshape(d, db, *shape) + + def maybe_pad(x, actually_pad=True): + if not actually_pad: return x # For call-site convenience below. + return jax.tree.map(pad, x) + + args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)] + kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()} + out = wrapped(*args, **kw) + + def unpad(x): + # Transfer back before cutting, to reduce on-device shape diversity. + return einops.rearrange(jax.device_get(x), "d b ... -> (d b) ...")[:b] + return jax.tree.map(unpad, out) + + return pad_shard_unpad_wrapper + + +def onehot(labels, num_classes, on_value=1.0, off_value=0.0): + x = (labels[..., None] == jnp.arange(num_classes)[None]) + x = jax.lax.select(x, jnp.full(x.shape, on_value), + jnp.full(x.shape, off_value)) + return x.astype(jnp.float32) + + +def npload(fname): + """Loads `fname` and returns an np.ndarray or dict thereof.""" + # Load the data; use local paths directly if possible: + if os.path.exists(fname): + loaded = np.load(fname, allow_pickle=False) + else: + # For other (remote) paths go via gfile+BytesIO as np.load requires seeks. + with gfile.GFile(fname, "rb") as f: + data = f.read() + loaded = np.load(io.BytesIO(data), allow_pickle=False) + + # Support loading both single-array files (np.save) and zips (np.savez). + if isinstance(loaded, np.ndarray): + return loaded + else: + return dict(loaded) + + +def load_checkpoint_np(npz, tree=None): + """Loads a jax pytree from a npz file. + + Args: + npz: Either path to the checkpoint file (.npz), or a dict-like. + tree: deprecated, use None. + Bwd-compat for old format that only stored values: the pytree structure. + + Returns: + A pytree that is the checkpoint. + """ + if isinstance(npz, str): # If not already loaded, then load. + npz = npload(npz) + keys, values = zip(*list(npz.items())) + if tree: + checkpoint = tree.unflatten(values) + else: + checkpoint = recover_tree(keys, values) + return checkpoint + + +def load_params(ckpt, **kw): + """Loads the parameters of a big_vision checkpoint, both old or new format. + + Args: + ckpt: Path to the checkpoint (.npz, .ts) or dict-like. + **kw: forwarded to the underlying load function (_np or _ts). + + Returns: + A pytree that is the checkpoint, potentially sharded. + + Notes: + The `ckpt` string can contain an colon-separated "submodel" indicator, like + `img` in the example `/path/to/file.npz:img`. + This is used to load sub-parts of a model, for example the image load the + image encoder out of a two_tower (SigLIP) checkpoint, or distillation. + This way, ANY model that uses this function can load itself from a + checkpoint that contains multiple sub-models. + """ + key = None # Whether we want to extract only a sub-key of the model. + + if isinstance(ckpt, str): # Most common case of passing a checkpoint path. + # Potentially read out the sub-part to load from after the colon + # '/path/to/file:img/head' => '/path/to/file', 'img/head' + # 'gs://path/to/file' => 'gs://path/to/file', None + if match := re.match(r"^(.*?/.*?)(?::([\w/]+))?$", ckpt): + ckpt, key = match.groups() + else: + raise ValueError(f"Weird ckpt path: {ckpt} ; Maybe prepend ./ ?") + + # Use the checkpoint filename to detect when we're loading old-style .npz + # checkpoints, as opposed to new-style tensorstore checkpoint folders. + if ".npz" in ckpt: # Not a perfect heuristic, but good enough. + checkpoint = load_checkpoint_np(ckpt, **kw) + checkpoint = jax.tree.map(recover_dtype, checkpoint) + if "params" in checkpoint: + # Checkpoint with optax state (after (internal link)). + params = checkpoint["params"] + elif "opt" in checkpoint: + # Checkpoint with Flax optimizer. + params = checkpoint["opt"]["target"] + else: + # When open-sourcing, we often shared only the params directly. + params = checkpoint + else: + # Here we're now loading new-style tensorstore checkpoints. + # We can be a more efficient and load params and `key` only right away. + regex = f"params/{key}($|/.*)" if key else "params/.*" + assert "regex" not in kw, "For a custom regex, use tsload directly." + kw["regex"] = regex + checkpoint = load_checkpoint_ts(ckpt, **kw) + params = checkpoint["params"] + + if key is not None: + params = tree_get(params, key) + + return params + + +def prefetch_scalar(it, nprefetch=1, devices=None): + n_loc_dev = len(devices) if devices else jax.local_device_count() + repl_iter = (np.ones(n_loc_dev) * i for i in it) + return flax_utils.prefetch_to_device(repl_iter, nprefetch, devices) + + +def sigmoid_xent(*, logits, labels, reduction=True): + # NOTE: This implementation is stable, see these two: + # (internal link) + # https://github.com/google/jax/issues/2140 + log_p = jax.nn.log_sigmoid(logits) + log_not_p = jax.nn.log_sigmoid(-logits) + nll = -jnp.sum(labels * log_p + (1. - labels) * log_not_p, axis=-1) + return jnp.mean(nll) if reduction else nll + + +def bidirectional_contrastive_loss(zimg, ztxt, t, mask=None, reduction=False): + """Bidirectional contrastive loss (e.g. for contrastive trainer/evaluator).""" + # BF.FB = BB + logits = jnp.dot(zimg, ztxt.T) * t + + if mask is not None: + # Set to negative infinity where mask = 0. Masked examples will disappear + # under softmax, and be ignored by ncorrect (NINF will never win argmax). + exclude = jnp.logical_not(mask) # Now 1 if we don't want to keep. + exclude = jnp.logical_or(exclude[:, None], exclude[None, :]) + logits = jnp.where(exclude, -jnp.inf, logits) + + # Note: assumed t is in a good range e.g. already passed through exp/softplus. + l1 = -jnp.diag(jax.nn.log_softmax(logits, axis=1)) # NLL img->txt + l2 = -jnp.diag(jax.nn.log_softmax(logits, axis=0)) # NLL txt->img + l = 0.5 * (l1 + l2) + + if mask is not None: + l = jnp.where(mask, l, 0) + + redux = jnp.mean if reduction else lambda x: x + if reduction and mask is not None: + redux = lambda x: jnp.sum(x * mask) / (jnp.sum(mask) + 1e-8) + + # Also return extra measurements. + return redux(l), { + "ncorrect": redux(jnp.argmax(logits, axis=1) == jnp.arange(len(logits))), + } + + +def softmax_xent(*, logits, labels, reduction=True, kl=False, axis=-1): + log_p = jax.nn.log_softmax(logits, axis=axis) + nll = -jnp.sum(labels * log_p, axis=axis) + if kl: + nll += jnp.sum(labels * jnp.log(jnp.clip(labels, 1e-8)), axis=axis) + return jnp.mean(nll) if reduction else nll + + +def weighted_softmax_xent(*, + logits, + labels, + reduction=True, + weights=None, + label_smoothing=0.0, + normalize=True): + """Compute weighted cross entropy. + + Args: + logits: [batch, length, num_classes] float array. + labels: categorical targets [batch, length] int array. + reduction: reduce across batch dim. + weights: None or array of shape [batch, length]. + label_smoothing: label smoothing constant, used to determine the on and off + values. + normalize: normalize each "sentence" loss by the number of tokens in it. + + Returns: + Tuple of scalar loss and batch normalizing factor. + """ + if logits.ndim != labels.ndim + 1: + raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % + (str(logits.shape), str(labels.shape))) + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + soft_targets = onehot( + labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1) + + normalizing_factor = labels.shape[1] + if weights is not None: + loss = loss * weights + normalizing_factor = jnp.clip(weights.sum(axis=1), 2e-38) + + loss = loss.sum(axis=1) + if normalize: + loss = loss / normalizing_factor + + return loss.mean() if reduction else loss + + +def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps): + """Accumulate gradient over multiple steps to save on memory.""" + # See (internal link) for details and experiments. + if accum_steps and accum_steps > 1: + assert images.shape[0] % accum_steps == 0, ( + f"Bad accum_steps {accum_steps} for batch size {images.shape[0]}") + step_size = images.shape[0] // accum_steps + l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size]) + def acc_grad_and_loss(i, l_and_g): + imgs = jax.lax.dynamic_slice(images, (i*step_size, 0, 0, 0), + (step_size,) + images.shape[1:]) + lbls = jax.lax.dynamic_slice(labels, (i*step_size, 0), + (step_size, labels.shape[1])) + li, gi = loss_and_grad_fn(params, imgs, lbls) + l, g = l_and_g + return (l + li, jax.tree.map(lambda x, y: x + y, g, gi)) + l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g)) + return jax.tree.map(lambda x: x / accum_steps, (l, g)) + else: + return loss_and_grad_fn(params, images, labels) + + +def itstime(step, every_n_steps, total_steps, host=None, last=True, first=True, + drop_close_to_last=0.25): + """Returns True if it's time to execute an action. + + Args: + step: the current step representing "now". + every_n_steps: the action should run every this many steps. + total_steps: the step number of the last step of training. + host: host number. If provided, only run if we are this process. + last: whether to run on the last step or not. + first: whether to run on the first step or not. + drop_close_to_last: if a step would run, but is this close (in terms of + fraction of every_n_step) to the last one, skip. + + Returns: + True if the action should be executed, False if not. + """ + + # This logic avoids running `itstime` "a few" steps before the last step. + # Canonical example: don't save checkpoint 2 steps before the last, and then + # at the last again; it's pointless and checkpoint timing will time out. + close_to_last = False + if drop_close_to_last and every_n_steps: + close_to_last = abs(step - total_steps) < drop_close_to_last * every_n_steps + + is_host = host is None or jax.process_index() == host + is_step = every_n_steps and (step % every_n_steps == 0) and not close_to_last + is_last = every_n_steps and step == total_steps + is_first = every_n_steps and step == 1 + return is_host and (is_step or (last and is_last) or (first and is_first)) + + +def checkpointing_timeout(writer, timeout): + # Make sure checkpoint writing is not a bottleneck + if writer is not None: + try: + # Note: `writer` is a multiprocessing.AsyncResult, and + # timeout is in seconds. + writer.get(timeout=timeout) + except multiprocessing.TimeoutError as e: + raise TimeoutError( + "Checkpoint writing seems to be a bottleneck. Make sure you do " + "not do something wrong, like writing checkpoints to a distant " + "cell. In a case you are OK with checkpoint writing being a " + "bottleneck, you can configure `ckpt_timeout` parameter") from e + + +def hms(s): + """Format time in hours/minutes/seconds.""" + if s < 60: + return f"{s:.0f}s" + m, s = divmod(s, 60) + if m < 60: + return f"{m:.0f}m{s:.0f}s" + h, m = divmod(m, 60) + if h < 25: + return f"{h:.0f}h{m:.0f}m" # Seconds intentionally omitted. + d, h = divmod(h, 24) + return f"{d:.0f}d{h:.0f}h{m:.0f}m" # Seconds intentionally omitted. + + +class Chrono: + """Measures time and reports progress, hyper-specific to our train loops. + + Some concepts: + 1. This differentiates between three "types" of time: + - training time: the time spent on actual training (fprop/bprop/update) + - program time: overall time the program runs, including all overheads + - pause time: the chronometer can be paused (eg during evals). + 2. This handles a "warmup": the first step is skipped for training time + purposes, as it includes significant compilation overheads, which distort + estimates. + 3. `accum`ulates (i.e. integrates) timings, and save/load them across + restarts. + """ + + def __init__(self): + self._timing_history = collections.defaultdict(list) + self._measure = None + self._write_note = None + + self.program_start_time = time.monotonic() + self.train_start_time = None + self.train_start_step = None # When we started timing (after warmup) + + self.prev_time = None + self.prev_step = None + + self.pause_start = None + self.paused_time = 0 + + self.total_steps = None + self.global_bs = None + self.steps_per_epoch = None + + self.warmup = 2 # How many calls to `tick` to skip. + self.load() # Inits accum integrators. + self.note = "Chrono n/a" + + def inform(self, *, first_step=None, total_steps=None, global_bs=None, + steps_per_epoch=None, measure=None, write_note=None): + """Provide some extra info that's only known later in the program.""" + # The pattern of `self.x = x or self.x` allows one to call `inform` various + # times with various subset of information (args), as they become available. + # Except for `first_step` which can be 0 so is a bit more verbose. + self.prev_step = first_step if first_step is not None else self.prev_step + self.total_steps = total_steps or self.total_steps + self.steps_per_epoch = steps_per_epoch or self.steps_per_epoch + self.global_bs = global_bs or self.global_bs + self._measure = measure or self._measure + self._write_note = write_note or self._write_note + if self.total_steps and self.prev_step is not None: + self.note = (f"Steps:{self.prev_step}/{self.total_steps} " + f"[{self.prev_step/self.total_steps:.1%}]") + + def tick(self, step, measure=None, write_note=None): + """A chronometer tick.""" + if step == self.prev_step: return # Can happen from evals for example. + + measure = measure or self._measure + write_note = write_note or self._write_note + + now = time.monotonic() + measure("uptime", now - self.program_start_time) + self.flush_timings() + + # We do always count examples, regardless of the timing-related warmup that + # happens a few lines below. + ds = step - self.prev_step # Steps between ticks + self.prev_step = step + self.accum_examples_seen += ds * self.global_bs + measure("examples_seen", self.accum_examples_seen) + measure("progress", step / self.total_steps) + if self.steps_per_epoch: + measure("epoch", step / self.steps_per_epoch) + + # We take the start as the second time `tick` is called, so we avoid + # measuring the overhead of compilation and don't include it in time + # estimates. + if self.warmup > 1: + self.warmup -= 1 + write_note(self.note) # This can help debugging. + return + if self.warmup == 1: + self.train_start_time = self.prev_time = now + self.train_start_step = step + self.accum_program_time += now - self.program_start_time + self.paused_time = 0 # Drop pauses that happened before timing starts. + self.warmup = 0 + write_note(self.note) # This can help debugging. + return + + # Measurement with micro-timings of current training steps speed. + # Time between ticks (ignoring pause) + dt = now - self.prev_time - self.paused_time + ncores = jax.device_count() # Global device count + measure("img/sec/core", self.global_bs * ds / dt / ncores) + + # Accumulate (integrate) times, good for plots. + self.accum_train_time += dt + self.accum_pause_time += self.paused_time + self.accum_program_time += dt + self.paused_time + + # Convert to, and log as, core hours. + core_hours = self.accum_train_time * ncores / 60 / 60 + devtype = jax.devices()[0].device_kind + measure(f"core_hours_{devtype}", core_hours) + measure("core_hours", core_hours) # For convenience as x-axis in sweeps. + + # Progress note with "global" full-program average timings + # (eg in program-time minus warmup) + dt = now - self.train_start_time # Time elapsed since end of warmup. + steps_timed = step - self.train_start_step + steps_todo = self.total_steps - step + self.note = f"Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]" + self.note += f"\nWalltime:{hms(self.accum_program_time)}" + self.note += f" ({hms(self.accum_pause_time)} eval)" + self.note += f"\nETA:{hms(dt / steps_timed*steps_todo)}" + self.note += f"\nTotal train time:{hms(dt / steps_timed*self.total_steps)}" + write_note(self.note) + + log_memory(measure) + + self.prev_time = now + self.paused_time = 0 + + def pause(self, wait_for=()): + assert self.pause_start is None, "Don't pause twice." + jax.block_until_ready(wait_for) + self.pause_start = time.monotonic() + + def resume(self): + self.paused_time += time.monotonic() - self.pause_start + self.pause_start = None + + def save(self): + return dict( + accum_program_time=self.accum_program_time, + accum_train_time=self.accum_train_time, + accum_pause_time=self.accum_pause_time, + accum_examples_seen=self.accum_examples_seen, + ) + + def load(self, ckpt={}): # pylint: disable=dangerous-default-value + self.accum_program_time = float(ckpt.get("accum_program_time", 0.0)) + self.accum_train_time = float(ckpt.get("accum_train_time", 0.0)) + self.accum_pause_time = float(ckpt.get("accum_pause_time", 0.0)) + self.accum_examples_seen = int(ckpt.get("accum_examples_seen", 0)) + + @contextlib.contextmanager + def log_timing(self, name, *, noop=False): + """Use this when you time sth once per step and want instant flushing.""" + t0 = time.monotonic() + yield + dt = time.monotonic() - t0 + if not noop: + if self._measure: # So that timed things still work in colab. + self._measure(name, dt) + logging.info("TIMING[%s]: %s", name, dt) + logging.flush() + + @contextlib.contextmanager + def log_timing_avg(self, name, *, noop=False): + """Use this when you time sth multiple times per step (eg in a loop).""" + t0 = time.monotonic() + yield + dt = time.monotonic() - t0 + if not noop: + self._timing_history[name].append(dt) + logging.info("TIMING[%s]: avg %s current %s", + name, np.mean(self._timing_history[name]), dt) + logging.flush() + + def flush_timings(self): + assert self._measure is not None + for name, times in self._timing_history.items(): + self._measure(name, np.mean(times)) + self._timing_history.clear() + + +# Singleton to use from everywhere. https://stackoverflow.com/a/6760726/2366315 +chrono = Chrono() + + +def log_memory(measure): + """Log a bunch of memory-related measurements.""" + try: + import psutil + except ImportError: + psutil = None + + if psutil is not None: + # Note that total != available + used, see psutil docs. + vmem = psutil.virtual_memory() + measure("y/hostmem/total", vmem.total) + measure("y/hostmem/available", vmem.available) + measure("y/hostmem/used", vmem.used) + + # We show only device 0 and 1 to avoid spam. The reason to show two and not + # just one, if multiple are available, is because a frequent mistake is to + # create arrays on the default device, which is device 0. + for i, d in zip([0, 1], jax.local_devices()): + for k, v in (d.memory_stats() or {}).items(): + measure(f"y/devmem/dev{i}/{k}", v) + + +def _traverse_with_names(tree, with_inner_nodes=False): + """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" + if dataclasses.is_dataclass(tree): + tree = flax.serialization.to_state_dict(tree) + # Don't output the non-leaf nodes. If the optimizer doesn't have a state + # the tree leaves can be Nones which was interpreted as a leaf by this + # function but not by the other functions (like jax.tree.map). + if tree is None: + return + elif isinstance(tree, Mapping): + keys = sorted(tree.keys()) + for key in keys: + for path, v in _traverse_with_names(tree[key], with_inner_nodes): + yield (key + "/" + path).rstrip("/"), v + if with_inner_nodes: + yield "", tree + elif isinstance(tree, (list, tuple)): + for idx in range(len(tree)): + for path, v in _traverse_with_names(tree[idx], with_inner_nodes): + yield (str(idx) + "/" + path).rstrip("/"), v + if with_inner_nodes: + yield "", tree + else: + yield "", tree + + +def tree_flatten_with_names(tree): + """Populates tree_flatten with leaf names. + + This function populates output of tree_flatten with leaf names, using a + custom traversal that produces names is provided. The custom traversal does + NOT have to traverse tree in the same order as jax, as we take care of + automatically aligning jax' and custom traversals. + + Args: + tree: python tree. + + Returns: + A list of values with names: [(name, value), ...] + """ + vals, tree_def = jax.tree.flatten(tree) + + # "Fake" token tree that is use to track jax internal tree traversal and + # adjust our custom tree traversal to be compatible with it. + tokens = range(len(vals)) + token_tree = tree_def.unflatten(tokens) + val_names, perm = zip(*_traverse_with_names(token_tree)) + inv_perm = np.argsort(perm) + + # Custom traverasal should visit the same number of leaves. + assert len(val_names) == len(vals) + + return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def + + +def tree_unflatten(names_and_vals): + """Reverses `tree_flatten_with_names(tree)[0]`.""" + return recover_tree(*zip(*names_and_vals)) + + +def tree_map_with_names(f, tree, *rest): + """Like jax.tree.map but with a filter on the leaf path name. + + Args: + f: A function with first parameter `name` (path-like "a/b/c") and remaining + parameters values of `tree` and `*rest` corresponding to the given `name` + Should return a new value for parameter `name`. + tree: The tree of parameters `f` should be applied to. + *rest: more trees of the exact same structure. + + Returns: + A tree identical in structure to `tree` and `*rest` but with the leaves the + result of calling `f` on corresponding name/leaves in `tree` and `*rest`. + """ + names_and_vals, tree_def = tree_flatten_with_names(tree) + names, vals = zip(*names_and_vals) + rest_vals = [list(zip(*tree_flatten_with_names(t)[0]))[1] for t in rest] + vals = [f(*name_and_vals) for name_and_vals in zip(names, vals, *rest_vals)] + return tree_def.unflatten(vals) + + +def tree_map_with_regex(f, tree, regex_rules, not_f=lambda x: x, name=None): + """Apply jax-style tree_map based on regex rules. + + Args: + f: a function that is being applied to every variable. + tree: jax tree of arrays. + regex_rules: a list of tuples `(pattern, args)`, where `pattern` is a regex + which used for variable matching and `args` are positional arguments + passed to `f`. If some variable is not matched, we apply `not_f` transform + which is id by default. If multiple patterns match, then only the first + rule is applied. + not_f: optional function which is applied to variables that do not match any + pattern. + name: a name of transform for logging purposes. + + Returns: + a tree, transformed by `f` according to the given rules. + """ + def _f(vname, v): + for pattern, arg in regex_rules: + if re.fullmatch(pattern, vname): + if name and jax.process_index() == 0: + logging.info("Applying %s to %s with %s due to `%s`", + name, vname, arg, pattern) + return f(v, arg) + return not_f(v) + return tree_map_with_names(_f, tree) + + +def tree_get(tree, name): + """Get an entry of pytree by flattened key name, eg a/b/c, with nice error. + + Args: + tree: the pytree to be queried. + name: the path to extract from the tree, see below for examples. + + Returns: + A few examples: + tree = {'a': 1, 'b': {'c': 2, 'd': 3}} + tree_get(tree, 'a') == 1 + tree_get(tree, 'b/c') == 2 + tree_get(tree, 'b') == {'c': 2, 'd': 3} + """ + flattened = dict(_traverse_with_names(tree, with_inner_nodes=True)) + try: + return flattened[name] + except KeyError as e: + class Msg(str): # Reason: https://stackoverflow.com/a/70114007/2366315 + def __repr__(self): + return str(self) + msg = "\n".join([name, "Available keys:", *flattened, ""]) + # Turn into configdict to use its "did you mean?" error message! + msg = mlc.ConfigDict(flattened)._generate_did_you_mean_message(name, msg) # pylint: disable=protected-access + raise KeyError(Msg(msg)) from e + + +def tree_replace(tree, replacements): + """Renames/removes (nested) keys. + + Example usage: + + tree = {'a': {'b': 2, 'c': 3}, 'c': 4} + replacements = { + 'a/b': 'a/b/x', # replaces 'a/b' with 'a/b/x' + '.*c': 'C', # replaces 'c' with 'C' ('a/c' is removed) + 'C': 'D', # replaces 'C' (which was 'c') with 'D' + '.*/c': None, # removes 'a/c' + } + tree2 = rename_remove(tree, replacements) + assert tree2 == {'D': 4, 'a': {'b': {'x': 2}}} + + Args: + tree: A nested dictionary. + replacements: Rules specifying `regex` as keys and `replacement` as values + to be used with `m = re.match(regex, key)` and `m.expand(replacement)` + for every `key` independently. + + Note that: + 1. If any rule matches with `replacement=None`, then the key is removed. + 2. The rules are applied in order. It's possible to have multiple + transformations on a single key. + + Returns: + Updated `tree` according to rules defined in `replacements`. + """ + replacements = { + re.compile(kk): vv for kk, vv in replacements.items() + } + + def rename(k): + for kk, vv in replacements.items(): + m = kk.match(k) + if m: + k = k[:m.start()] + m.expand(vv) + k[m.end():] + return k + + def should_remove(k): + return any(vv is None and kk.match(k) for kk, vv in replacements.items()) + + names_and_vals, _ = tree_flatten_with_names(tree) + names_and_vals = [ + (rename(k), v) for k, v in names_and_vals if not should_remove(k) + ] + return tree_unflatten(names_and_vals) + + +def tree_compare(tree1, tree2): + """Returns `(tree1_only, tree2_only, dtype_shape_mismatch)`.""" + tree1 = flax.traverse_util.flatten_dict(tree1, sep="/") + tree2 = flax.traverse_util.flatten_dict(tree2, sep="/") + return set(tree1) - set(tree2), set(tree2) - set(tree1), { + k: [(v.dtype, v.shape), (tree2[k].dtype, tree2[k].shape)] + for k, v in tree1.items() + if k in tree2 and (v.dtype != tree2[k].dtype or v.shape != tree2[k].shape) + } + + +def tree_filter(tree, mask): + """Returns nested dict structure with only a subset of children.""" + # TODO: The code below only works for nested-dict and only when they + # have same structure. Consider relax this. + if not isinstance(tree, dict): + assert isinstance(mask, bool), f"Mask leaves must be boolean! {mask}" + return tree + assert sorted(tree.keys()) == sorted(mask.keys()), ( + f"Keys in tree and mask are not equal! {tree.keys()} != {mask.keys()}") + return {k: tree_filter(v, mask[k]) for k, v in tree.items() + if mask[k] is not False} + + +def recover_dtype(a): + """Numpy's `save` stores bfloat16 type as "void" type, so we recover it.""" + if hasattr(a, "dtype") and a.dtype.type is np.void: + assert a.itemsize == 2, "Unknown dtype!" + return a.view(jax.numpy.bfloat16) + else: + return a + + +def recover_tree(keys, values): + """Recovers a tree as a nested dict from flat names and values. + + This function is useful to analyze checkpoints that are saved by our programs + without need to access the exact source code of the experiment. In particular, + it can be used to extract an reuse various subtrees of the scheckpoint, e.g. + subtree of parameters. + + Args: + keys: a list of keys, where '/' is used as separator between nodes. + values: a list of leaf values. + + Returns: + A nested tree-like dict. + """ + tree = {} + sub_trees = collections.defaultdict(list) + for k, v in zip(keys, values): + if "/" not in k: + tree[k] = v + else: + k_left, k_right = k.split("/", 1) + sub_trees[k_left].append((k_right, v)) + for k, kv_pairs in sub_trees.items(): + k_subtree, v_subtree = zip(*kv_pairs) + tree[k] = recover_tree(k_subtree, v_subtree) + return tree + + +def tssave(mngr, pytree, path, on_commit=lambda *_, **__: None): + """Save pytree using jax tensorstore-based checkpoint manager. + + NOTE: When overwriting an existing checkpoint with a different pytree, the + result is, counterintuitively, the union of both, not only the new one. + + Args: + mngr: An instance of GlobalAsyncCheckpointManager. + pytree: What to store; any pytree of arrays. + path: Where to save the pytree. Creates subfolders as needed. + on_commit: A callback when writing is done, see `mngr.serialize`. + """ + names, vals = zip(*tree_flatten_with_names(pytree)[0]) + + for name in names: + if "~" in name: + raise ValueError(f"Symbol '~' is not allowed in names. Found in {name}.") + + gfile.makedirs(path) + with jax.transfer_guard("allow"): + names = [name.replace("/", "~") for name in names] + mngr.serialize_with_paths( + list(vals), [os.path.join(path, name) for name in names], + on_commit_callback=functools.partial(on_commit, array_names=names)) + + +def save_checkpoint_ts(mngr, checkpoint, path, step, keep=True): + """Preemption-safe saving of checkpoints using tssave.""" + # The tensorstore checkpoint format is a folder with (potentially) many files. + # On some file-systems, operations on these (copy, rename, delete) are slow, + # so we implement a flow that's both robust to pre-emptions/crashes during + # checkpointing and makes minimal use of these slow operations. + + # The logic goes as follows. It's infaillible :) + # (...if file move is atomic, which it is.) + # We always write the current checkpoint to a new folder, which contains the + # step number in its name. If we don't need to keep it indefinitely, we append + # "-tmp" to its name. + # After writing the next checkpoint, we remove the previous one if it had + # "-tmp" in its name. + # We also have a -LAST file that contains a pointer to the latest complete + # checkpoint. File operations are cheap to make atomic, that's why. + + def _on_commit_callback(array_names): # Runs after writing ckpt is done. + with gfile.GFile(f"{path}-CUR", "w") as f: + f.write(curr) + + last = "" + if gfile.exists(f"{path}-LAST"): + with gfile.GFile(f"{path}-LAST", "r") as f: + last = f.read().strip() + + gfile.rename(f"{path}-CUR", f"{path}-LAST", overwrite=True) + + if last.endswith("-tmp"): + # If pre-emption happens here, some old checkpoints may not be deleted. + multiprocessing.pool.ThreadPool().map( + gfile.rmtree, + [f"{path}-{last}/{name}" for name in array_names]) + gfile.rmtree(f"{path}-{last}") + + # NOTE: The jax checkpoint manager automatically waits for the previous save + # to be finished before writing again, so we don't need to do it here. + + # Always write to path with step number in it. + curr = f"{step:09d}{'-tmp' if not keep else ''}" + tssave(mngr, checkpoint, f"{path}-{curr}", _on_commit_callback) + + +def load_checkpoint_ts(path, **tsload_kw): + """Loads a big_vision checkpoint saved by `save_checkpoint_ts`.""" + to_load = path + + try: + # When passing a general path (not a specific step), get the last available. + with gfile.GFile(f"{path}-LAST", "r") as f: + to_load = f"{path}-{f.read().strip()}" + except Exception: # Differs based on backend, so blanket catch. pylint:disable=broad-exception-caught + pass + + return tsload(to_load, **tsload_kw) + + +def tsload(path, *, tree=None, shardings=None, regex=None): + """Loads tensorstore-based array-tree from disk. + + If `tree` argument is provided, then array names to load and target structure + is derived from the tree. If `tree` is None, then array names to load are + derived from array filenames on the disk, and, optionally, `regex` is applied + to filter these names. The`tree` argument is then automatically derived from + array names with `recover_tree` util. + + Arrays are loaded to CPU/TPU/GPU memory as specified by the `shardings` + argument, which is a pytree of CPU/TPU/GPU shardings (can be mixed within a + single pytree). `shardings` should a prefix tree of the `tree` argument. We + automatically broadcast `shardings` to a full `tree`. For example, a user can + specify `shardings=jax.sharding.SingleDeviceSharing(jax.devices('cpu')[0])`, + which will be broadcasted to a full tree. + + Args: + path: a directory where the checkpoint arrays are stored. + tree: a target pytree, which defines array names to load and the target tree + structure. If tree is None, then `tree` is inferred from the names of + arrays stored on the disk. + shardings: a prefix pytree (with respect to `tree`) of the target shardings. + regex: regex to filter array names from the disk, if `tree` is not provided. + + Returns: + A pytree of loaded arrays that has the same structure as `shardings` arg. + """ + if (tree is not None) and (regex is not None): + raise ValueError("If tree is specified, regex filtering is not allowed.") + + if tree is None: + # Some file-systems (gs://) list folders with a trailing /, get rid of it. + path_names = set([p.rstrip("/").replace("~", "/") + for p in gfile.listdir(path)]) + regex = re.compile(regex) if regex is not None else re.compile(".*") + path_names = [p for p in path_names if regex.match(p)] + tree = recover_tree(path_names, [0] * len(path_names)) + + names_and_vals, tree_def = tree_flatten_with_names(tree) + names_to_load, _ = zip(*names_and_vals) + + if shardings is None: + shardings = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend="cpu")[0] + ) + shardings = list(jax.tree.leaves(tree_broadcast(shardings, tree))) + + names_to_load = [os.path.join(path, name.replace("/", "~")) + for name in names_to_load] + specs = [array_serial.get_tensorstore_spec(n) for n in names_to_load] + arrays = array_serial.run_deserialization(shardings, specs, concurrent_gb=64) + return tree_def.unflatten(arrays) + + +def steps(prefix, config, data_size=None, batch_size=None, total_steps=None, + default=ValueError): + """Gets duration named `prefix` out of `config` and converts it to steps. + + Using this function to access a configuration value that denotes some kind + of duration (eg training time, warmup, checkpoint frequency, ...) allows the + duration to be specified in terms of steps, epochs, examples, or percent of + training time, and converts any of these into steps, such that the training + code only deals with steps. + If the result is not an integer step number, it is rounded to the nearest one. + + Args: + prefix: The name of the duration to query. The actual config fields can + then be one of `prefix_steps`, `prefix_examples`, or `prefix_epochs`. + config: The dictionary (config) from which to read the duration. + data_size: The total number of training examples in one epoch. + batch_size: The number of examples processed per step. + total_steps: The total number of training steps to run. + default: The default value to return when no duration of the name `prefix` + is found in the `config`. Set to `ValueError` (the default) to raise an + error instead of returning a default value. + + Returns: + The number of steps from the config, or the default value. + + Raises: + ValueError if there is no such duration in the config and no default is set. + """ + # Be helpful and make sure only match one of the following suffixes. + suffixes = {"steps", "examples", "epochs", "percent"} + matches = { + f"{prefix}_{s}" + for s in suffixes + if (x := config.get(f"{prefix}_{s}")) is not None and x >= 0 + } + # Note that steps=0 is also a valid value (e.g. to only run evaluators). + assert len(matches) <= 1, f"Only one of '{matches}' should be defined." + + if f"{prefix}_steps" in matches: + return config[f"{prefix}_steps"] + + def to_integer(x): + # Round to nearest but always executed at least one step unless explictily + # asked for 0. E.g. total_epochs=0 vs total_epochs=0.0001 + return max(1, round(x)) if x else 0 + + if batch_size and f"{prefix}_examples" in matches: + return to_integer(config[f"{prefix}_examples"] / batch_size) + + if batch_size and data_size and f"{prefix}_epochs" in matches: + steps_per_epoch = data_size / batch_size + return to_integer(config[f"{prefix}_epochs"] * steps_per_epoch) + + if total_steps and f"{prefix}_percent" in matches: + pct = config[f"{prefix}_percent"] + assert 0.0 <= pct <= 1.0, ( # Be helpful, since it's not obvious. + f"Percents should lie in [0.0, 1.0], but {prefix}_percent is {pct}") + return to_integer(pct * total_steps) + + if default is ValueError: + raise ValueError( + f"Cannot convert {prefix} to steps, due to missing batch_size " + f"({batch_size}), data_size ({data_size}), total_steps ({total_steps})" + ", or corresponding entry in config:\n" + "\n".join(config.keys())) + + return default + + +def create_learning_rate_schedule( + total_steps, batch_size=None, data_size=None, + base=1.0, decay_type="stair", + scale_with_batchsize=False, **kw): + """Creates learning rate schedule, see (internal link). + + Args: + total_steps: The total number of steps to run. + batch_size: The global batch-size optionally used for scaling. + data_size: Number of examples in the training data (for epoch conversion). + base: The starting learning-rate (without warmup). + decay_type: 'linear' or 'cosine', 'rsqrt', 'stair'. + scale_with_batchsize: Whether or not to scale lr automatically. + **kw: extra arguments specific to individual decay_types. Also contains + declaration of `{warmup,cooldown}_{steps,epochs,examples}` that applies + on top of any/all decay_type. + + Returns: + A function learning_rate(step): float -> {"learning_rate": float}. + """ + + def to_steps(name, default=0): + return steps(name, kw, data_size, batch_size, total_steps, default=default) + + warmup_steps = to_steps("warmup") + cooldown_steps = to_steps("cooldown") + + # Early catch hard to backtrack errors due to warmup_steps >= total_steps, + # but let it run for 0 and 1 steps used to eval and debug runs. + assert (total_steps <= 1) or (warmup_steps < total_steps), ( + "warmup_steps is >= total_steps") + + def step_fn(step): + """Step to learning rate function.""" + lr = base + + # This implements the linear scaling rule following + # Goyal et al. at arxiv.org/abs/1706.02677. + # The reference batch size in literature is 256, so we scale the lr to + # adjust to the literature lr when bach_size changes. + if scale_with_batchsize: + lr = lr * batch_size / 256.0 + + progress = (step - warmup_steps) / float(total_steps - warmup_steps) + progress = jnp.clip(progress, 0.0, 1.0) + if decay_type in ("linear", "polynomial"): + power = kw.get("power", 1) + zero = kw.get("end", kw.get("linear_end", 0)) + lr = zero + (lr - zero) * (1.0 - progress) ** power + elif decay_type == "cosine": + lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress)) + elif decay_type == "rsqrt": + # See (internal link) for details, especially how to set timescale + # and shift in order to continue smoothly when changing batch-size. + t = to_steps("timescale", default=kw.get("timescale", 10_000)) + shift = to_steps("shift", default=kw.get("shift", 0)) + lr = jnp.where( + warmup_steps <= step, + lr / jnp.sqrt(1 + (step + shift - warmup_steps) / t), # In decay + lr / jnp.sqrt(1 + shift / t)) # In warmup. + elif decay_type == "stair": + i = jnp.searchsorted(jnp.array(kw.get("steps", [])), step + 1) + lr = lr * jnp.take(jnp.array([1.0] + list(kw.get("mults", []))), i) + else: + raise ValueError(f"Unknown lr type {decay_type}") + + if warmup_steps: + lr = lr * jnp.minimum(1., step / warmup_steps) + if cooldown_steps: + lr = lr * jnp.minimum(1., (total_steps - step) / cooldown_steps) + + return jnp.asarray(lr, dtype=jnp.float32) + + return step_fn + + +def get_mixup(rng, p): + """Perform mixup https://arxiv.org/abs/1710.09412.""" + rng, rng_mixup = jax.random.split(rng) + a = jax.random.beta(rng_mixup, p, p) + a = jnp.maximum(a, 1.0 - a) # see (internal link) for the context. + def _mixup(*things, **more_things): + mix = lambda thing: a * thing + (1 - a) * jnp.roll(thing, shift=1, axis=0) + return rng, *jax.tree.map(mix, (things, more_things)) + return _mixup + + +# For backwards compatability with legacy code. +def mixup(rng, *things, p, **more_things): + return get_mixup(rng, p)(*things, **more_things) + + +def sync(): + """Syncs hosts and empties async computation queue.""" + x = reshard(np.ones(jax.device_count()), + jax.sharding.PositionalSharding(jax.devices())) + jax.jit(jnp.sum)(x).block_until_ready() + + +def check_and_compile_patterns(patterns): + """Validates and compiles a list of param-patterns. + + The validation consists of checking for common mistakes, currently only that + the pattern does not start with a slash, because unlike FLAX, our parameter + names don't start with a slash. + + Args: + patterns: a single (string) pattern (regex), or a list of patterns. + + Returns: + A list of compiled and verified regexes. + """ + if isinstance(patterns, str): + patterns = [patterns] + + assert isinstance(patterns, (list, tuple)), patterns + + def check_and_compile(pattern): + assert not pattern.startswith("/"), ( + f"Big vision parameter names never start with '/': '{pattern}") + return re.compile(pattern) + + return list(map(check_and_compile, patterns)) + + +def make_mask_trees(tree, patterns, *, log=None): + """Returns a boolean mask tree for every pattern (only first match).""" + compiled_patterns = check_and_compile_patterns(patterns) + + def matchfirst(name, _): + matches = [] + for pattern in compiled_patterns: + matches.append(not any(matches) and bool(pattern.fullmatch(name))) + if log is not None and True in matches and jax.process_index() == 0: + logging.info("%s: %s - matched by %s", log, name, + patterns[matches.index(True)]) + return np.array(matches) + + multimask = tree_map_with_names(matchfirst, tree) + return [ + jax.tree.map(lambda matches, i=idx: matches[i], multimask) + for idx in range(len(patterns)) + ] + + +@contextlib.contextmanager +def profile(name, ttl=3 * 365 * 24 * 3600, noop=False): + if not noop: + sess = startstop_prof_at_steps(None, name=name, ttl=ttl) + yield + if not noop: + startstop_prof_at_steps(sess, name=name, ttl=ttl) + + +def startstop_prof(sess, step=None, first_step=0, + log_steps=1, surround=10, **kw): + """Runs the profiler for `surround` steps around the next `log_steps`.""" + first_log = first_step + log_steps - (first_step % log_steps) + # don't start before first! + start = max(first_log - surround//2, first_step + 1) + return startstop_prof_at_steps(sess, step, start, start + surround, **kw) + + +def startstop_prof_at_steps( + sess, step=None, first_step=None, last_step=None, + name="steps", ttl=3 * 365 * 24 * 3600): + del sess, step, first_step, last_step, name, ttl + pass # TODO: implement using `jax.profiler` API. Needs workdir. + + +# This is a very minimal variant for open-sourcing. Our internal code makes use +# of multiple internal logging tools instead. +class BigVisionMetricWriter: + """A class for logging metrics.""" + + def __init__(self, xid=-1, wid=-1, workdir=None, config=None): + self.step_start(0) + if jax.process_index() != 0: return # Only one host shall write stuff. + + self.pool = multiprocessing.pool.ThreadPool(1) # 1 is important here. + self.fname = None + if workdir: + if xid != -1 and wid != -1: + self.fname = os.path.join(workdir, + f"big_vision_{xid}_{wid}_metrics.txt") + else: + self.fname = os.path.join(workdir, "big_vision_metrics.txt") + if config: + with gfile.GFile(os.path.join(workdir, "config.json"), "w") as f: + f.write(config.to_json()) + + def step_start(self, step): + self.step = step + self.step_metrics = {} + + def measure(self, name, value): + """Logs the metric value.""" + if jax.process_index() != 0: return # Only one host shall write stuff. + + # Convenience for accepting scalar np/DeviceArrays, as well as N-d single + # scalars, like [[[123]]] or similar, avoiding silly mistakes. + value = np.array(value).squeeze() + + # If the value is a scalar, we keep it in mind to append a line to the logs. + # If it has any structure, we instead just log its shape. + value = float(value) if value.ndim == 0 else value.shape + + logging.info(f"\u001b[35m[{self.step}]\u001b[0m {name} = {value}") + logging.flush() + self.step_metrics[name] = value + + return value # Just for convenience + + def step_end(self): + """Ends a training step, write its full row.""" + if not self.step_metrics: return + + def write(metrics): + with gfile.GFile(self.fname, "a") as f: + f.write(json.dumps({"step": self.step, **metrics}) + "\n") + + if self.fname: + self.pool.apply(lambda: None) # Potentially wait for past writes. + self.pool.apply_async(write, (self.step_metrics,)) + + def close(self): + self.step_end() + if jax.process_index() == 0: + self.pool.close() + self.pool.join() + + +def maybe_cleanup_workdir(workdir, cleanup, info): + """Potentially removes workdirs at end of run for cleanup.""" + if not workdir: + return + + if not cleanup: + info("Logs/checkpoints are in %s", workdir) + elif jax.process_index() == 0: + gfile.rmtree(workdir) + try: # Only need this on the last work-unit, if already empty. + gfile.remove(os.path.join(workdir, "..")) + except tf.errors.OpError: + pass + + +def tree_broadcast(prefix, target): + """Broadcasts a prefix tree to a full tree. + + Input-output examples: + 1. prefix: {"x": 10, "y": 20} + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: {"x": {"a": 10, "b": 10}, "y": 20} + + 2. prefix: 100 + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: {"x": {"a": 100, "b": 100}, "y": 100} + + 3. prefix: {"x": 10} + target: {"x": {"a": 1, "b": 2}, "y": 3} + + Result: ValueError + + Args: + prefix: prefix pytree. + target: boradcast target for a prefix tree. + + Returns: + prefix tree broadcasted to a target tree. + """ + def _broadcast(leaf, subtree): + return jax.tree.map(lambda _: leaf, subtree) + return jax.tree.map(_broadcast, prefix, target) + + +def reshard(tree, shardings): + """Take an arbitrarily* sharded pytree and shard it according to `shardings`. + + This is a no-op for tree elements which are already sharded as requested. + + *Arrays that are fully addressable (for example, CPU arrays) are assumed to be + identical (i.e. replicated) across hosts. + + *It does not work if an element of `tree` is not fully-addressable, unless its + sharding is already consistent with the target sharding. + If this is needed, please ping lbeyer@ or akolesnikov@. + + Args: + tree: a pytree of arrays. + shardings: a (prefix) pytree of jax array shardings. + Returns: + A pytree of global jax arrays that follows provided shardings. + """ + def _make_global_arr(x, shard, shape): + # Avoid unnecessary copies and transfers: + if hasattr(x, "sharding") and x.sharding.is_equivalent_to(shard, len(shape)): # pylint: disable=line-too-long + return x + if not getattr(x, "is_fully_addressable", True): + raise RuntimeError("Trying to reshard a non-fully-addressable array. " + "Please see the doc-comment for detailed explanation.") + x = jax.device_get(x) # Might be on local devices. + xs = [jax.device_put(x[s], device=d) + for d, s in shard.addressable_devices_indices_map(shape).items()] + return jax.make_array_from_single_device_arrays(shape, shard, xs) + + shapes = jax.tree.map(np.shape, tree) + shardings = tree_broadcast(shardings, tree) + return jax.tree.map(_make_global_arr, tree, shardings, shapes) + + +def put_cpu(x): + """Places array/pytree on a CPU device.""" + return jax.device_put(x, jax.local_devices(backend="cpu")[0]) + + +def make_fsarray_from_local_slice(local_slice, global_devices): + """Create a fully-sharded global device array from local host arrays. + + Args: + local_slice: Something convertible to a numpy array (eg also TF tensors) + that is this host's slice of the global array. + global_devices: The list of global devices. Needed for consistent ordering. + + Returns: + The global on-device array which consists of all local slices stacked + together in the order consistent with the devices. + """ + mesh = jax.sharding.Mesh(global_devices, ("devices",)) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("devices")) + local_ds = mesh.local_devices + + x = np.asarray(memoryview(local_slice)) # No-copy: http://(internal link) + xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds) + + global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:]) + return jax.make_array_from_single_device_arrays(global_shape, sharding, xs) + + +def get_local_slice_from_fsarray(global_array): + """Return numpy array for the host-local slice of fully-sharded array. + + Args: + global_array: JAX array, globally sharded on devices across hosts. + + Returns: + NumPy array that holds the part of `global_array` that is held by the + devices on the host that calls this function. + """ + # For now, for simplicity, we only implement slicing along the first axis. + for shard in global_array.addressable_shards: + assert all(idx == slice(None) for idx in shard.index[1:]), ( + f"global_array is sharded along non-first dimensions:\n{shard.index}") + + # Get the shards back in the same order in which the global array was created + # in the first place. This makes sure it's consistent with other things in the + # batch, for example (assuming the whole batch is consistent). + m = {s.device: s for s in global_array.addressable_shards} + local_shards = [m[d] for d in global_array.sharding.mesh.local_devices] + return np.concatenate([jax.device_get(s.data) for s in local_shards], axis=0) + + +def assert_local_slices_same(*global_arrays): + """Check whether all `global_arrays` have local slices at the same indices.""" + slices = [ + tuple( + tuple((idx.start, idx.end, idx.step) for idx in s.index) + for s in a.addressable_shards) + for a in global_arrays] + assert len(set(slices)) == 1, f"Not all slices are the same: {slices}" + + +# TODO: remove this logic when the +# issue is github fixed https://github.com/google/jax/issues/15600. +def jit_cpu(**extra_kwargs): + def _decorator(fun): + def _wrapped(*args, **kwargs): + sh = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend="cpu")[0] + ) + return jax.jit(fun, **extra_kwargs, out_shardings=sh)(*args, **kwargs) + return _wrapped + return _decorator + + +def create_device_mesh( + config_mesh, + *, + allow_split_physical_axes=False, +): + """Returns a JAX device mesh. + + Args: + config_mesh: A list of tuples of (axis_name, axis_size). It is advised to + sort the axis in increasing order of network communication intensity. + allow_split_physical_axes: Whether to allow splitting physical axes. + """ + devices = jax.devices() + mesh_axes, mesh_size = tuple(zip(*config_mesh)) + # Because jax.utils do not support `-1` shape size. + mesh_size = np.array(devices).reshape(mesh_size).shape + device_mesh = mesh_utils.create_device_mesh( + mesh_size, + devices=devices, + allow_split_physical_axes=allow_split_physical_axes) + return jax.sharding.Mesh(device_mesh, mesh_axes) diff --git a/big_vision/utils_test.py b/big_vision/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..20680ee689262ef10e7a69d0446cc35f0fcf08bf --- /dev/null +++ b/big_vision/utils_test.py @@ -0,0 +1,360 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for utils.""" + +from functools import partial +import os + +from absl.testing import parameterized +from big_vision import utils +import chex +import flax +import jax +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +import numpy as np +import tensorflow as tf + +from tensorflow.io import gfile + + +NDEV = 4 + + +def setUpModule(): + chex.set_n_cpu_devices(NDEV) + + +class PadShardUnpadTest(chex.TestCase, tf.test.TestCase): + BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1] + DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32] + + def tearDown(self): + chex.clear_trace_counter() + super().tearDown() + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_basics(self, dtype, bs): + # Just tests that basic calling works without exploring caveats. + @partial(utils.pad_shard_unpad, static_argnums=()) + def add(a, b): + return a + b + + x = jnp.arange(bs, dtype=dtype) + y = add(x, 10 * x) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + + @parameterized.parameters(DTYPES) + def test_min_device_batch_avoids_recompile(self, dtype): + @partial(utils.pad_shard_unpad, static_argnums=()) + @jax.jit + @chex.assert_max_traces(n=1) + def add(a, b): + return a + b + + chex.clear_trace_counter() + + for bs in self.BATCH_SIZES: + x = jnp.arange(bs, dtype=dtype) + y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_static_argnum(self, dtype, bs): + @partial(utils.pad_shard_unpad, static_argnums=(1,)) + def add(a, b): + return a + b + + x = jnp.arange(bs, dtype=dtype) + y = add(x, dtype(10)) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10)) + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_static_argnames(self, dtype, bs): + # In this test, leave static_argnums at the default value too, in order to + # test the default/most canonical path where `params` are the first arg. + @partial(utils.pad_shard_unpad, static_argnames=('b',)) + def add(params, a, *, b): + return params * a + b + + x = jnp.arange(bs, dtype=dtype) + y = add(dtype(5), x, b=dtype(10)) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10)) + + +class TreeTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + + self.d1 = {'w1': 1, 'w2': 2, 'w34': (3, 4)} + self.d1_flat = [1, 2] + self.d1_flat_jax = jax.tree.flatten(self.d1)[0] + self.d1_named_flat = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)] + self.d1_named_flat_jax = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)] + + self.d2 = {'conv1': {'kernel': 0, 'bias': 1}, + 'conv2': {'kernel': 2, 'bias': 3}} + self.d2_flat = [1, 0, 3, 2] + self.d2_flat_jax = jax.tree.flatten(self.d2)[0] + self.d2_named_flat = [('conv1/bias', 1), ('conv1/kernel', 0), + ('conv2/bias', 3), ('conv2/kernel', 2)] + self.d2_named_flat_jax = [('conv1/bias', 1), ('conv1/kernel', 0), + ('conv2/bias', 3), ('conv2/kernel', 2)] + self.d2_named_flat_inner = [ + ('conv1/bias', 1), ('conv1/kernel', 0), ('conv1', self.d2['conv1']), + ('conv2/bias', 3), ('conv2/kernel', 2), ('conv2', self.d2['conv2']), + ('', self.d2), + ] + + # This is a very important testcase that checks whether we correctly + # recover jax' traversal order, even though our custom traversal may not + # be consistent with jax' traversal order. In particular, jax traverses + # FlaxStruct in the order of attribute definition, while our custom + # traversal is alphabetical. + @flax.struct.dataclass + class FlaxStruct(): + v3: float + v2: int + v1: str + self.d3 = {'a': 0, 'flax': FlaxStruct(2.0, 1, 's')} + self.d3_flat = [0, 1, 2.0, 's'] + self.d3_flat_jax = jax.tree.flatten(self.d3)[0] + self.d3_named_flat = [ + ('a', 0), ('flax/v1', 's'), ('flax/v2', 1), ('flax/v3', 2.0)] + self.d3_named_flat_jax = [ + ('a', 0), ('flax/v3', 2.0), ('flax/v2', 1), ('flax/v1', 's')] + + def test_traverse_with_names(self): + names_and_vals = list(utils._traverse_with_names(self.d1)) + self.assertEqual(names_and_vals, self.d1_named_flat) + + names_and_vals = list(utils._traverse_with_names(self.d2)) + self.assertEqual(names_and_vals, self.d2_named_flat) + + names_and_vals = list(utils._traverse_with_names( + self.d2, with_inner_nodes=True)) + self.assertEqual(names_and_vals, self.d2_named_flat_inner) + + names_and_vals = list(utils._traverse_with_names(self.d3)) + self.assertEqual(names_and_vals, self.d3_named_flat) + + def test_tree_flatten_with_names(self): + names_and_vals = utils.tree_flatten_with_names(self.d1)[0] + self.assertEqual(names_and_vals, self.d1_named_flat_jax) + self.assertEqual([x for _, x in names_and_vals], self.d1_flat_jax) + + names_and_vals = utils.tree_flatten_with_names(self.d2)[0] + self.assertEqual(names_and_vals, self.d2_named_flat_jax) + self.assertEqual([x for _, x in names_and_vals], self.d2_flat_jax) + + names_and_vals = utils.tree_flatten_with_names(self.d3)[0] + self.assertEqual(names_and_vals, self.d3_named_flat_jax) + self.assertEqual([x for _, x in names_and_vals], self.d3_flat_jax) + + def test_tree_map_with_names(self): + d1 = utils.tree_map_with_names( + lambda name, x: -x if 'w2' in name else x, self.d1) + self.assertEqual(d1, {'w1': 1, 'w2': -2, 'w34': (3, 4)}) + + d1 = utils.tree_map_with_names( + lambda name, x1, x2: x1 + x2 if 'w2' in name else x1, self.d1, self.d1) + self.assertEqual(d1, {'w1': 1, 'w2': 4, 'w34': (3, 4)}) + + def test_recover_tree(self): + keys = ['a/b', 'a/c/x', 'a/c/y', 'd'] + values = [0, 1, 2, 3] + self.assertEqual(utils.recover_tree(keys, values), + {'a': {'b': 0, 'c': {'x': 1, 'y': 2}}, 'd': 3}) + + def test_make_mask_trees(self): + F, T = False, True # pylint: disable=invalid-name + tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}} + msk1 = {'a': {'b': F, 'x': T}, 'b': {'x': T, 'y': F}} + msk2 = {'a': {'b': F, 'x': F}, 'b': {'x': F, 'y': T}} + # Note that 'b' matches '^b' only and not '.*/b'. + # Also note that "b/x" is matched by rule 1 only (because it comes first). + self.assertEqual( + utils.make_mask_trees(tree, ('.*/x', 'b/.*')), [msk1, msk2]) + + def test_tree_get(self): + tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}} + self.assertEqual(utils.tree_get(tree, 'a/b'), 0) + self.assertEqual(utils.tree_get(tree, 'a/x'), 1) + self.assertEqual(utils.tree_get(tree, 'b/x'), 2) + self.assertEqual(utils.tree_get(tree, 'b/y'), 3) + self.assertEqual(utils.tree_get(tree, 'a'), tree['a']) + self.assertEqual(utils.tree_get(tree, 'b'), tree['b']) + + def test_tree_replace(self): + tree = {'a': {'b': 2, 'c': 3}, 'c': 4} + replacements = { + 'a/b': 'a/b/x', # replaces 'a/b' with 'a/b/x' + '.*c': 'C', # replaces 'c' with 'C' ('a/c' is removed) + 'C': 'D', # replaces 'C' (which was 'c') with 'D' + '.*/c': None, # removes 'a/c' + } + tree2 = utils.tree_replace(tree, replacements) + self.assertEqual(tree2, {'D': 4, 'a': {'b': {'x': 2}}}) + + def test_tree_compare(self): + tree1_only, tree2_only, dtype_shape_mismatch = utils.tree_compare( + {'a': {'b': jnp.array(2), 'c': jnp.array(3)}}, + {'a': {'B': jnp.array(2), 'c': jnp.array(3.)}}, + ) + self.assertEqual(tree1_only, {'a/b'}) + self.assertEqual(tree2_only, {'a/B'}) + self.assertEqual( + dtype_shape_mismatch, + {'a/c': [(jnp.dtype('int32'), ()), (jnp.dtype('float32'), ())]}) + + +class StepConversionTest(parameterized.TestCase, tf.test.TestCase): + + @parameterized.named_parameters( + ('nice_steps', 1000, None, None, dict(foo_steps=3), 3), + ('nice_epochs', 1000, 100, None, dict(foo_epochs=3), 30), + ('nice_examples', None, 100, None, dict(foo_examples=300), 3), + ('nice_percent', None, None, 10, dict(foo_percent=0.30), 3), + ('ignore_neg', 1000, 100, 10, dict(foo_steps=-1, foo_epochs=-1, + foo_examples=-1, foo_percent=0.30), 3), + ('zero_steps', None, None, 10, dict(foo_percent=0.0), 0), + ('offbyone_steps', 1001, None, None, dict(foo_steps=3), 3), + ('offbyone_epochs', 1001, 100, None, dict(foo_epochs=3), 30), + ('offbyone_examples', None, 101, None, dict(foo_examples=300), 3), + ('offbyone_percent', None, None, 11, dict(foo_percent=0.30), 3), + ) + def test_steps(self, data_size, batch_size, total, cfg, expected): + # Correct default usage: + step = utils.steps('foo', cfg, data_size=data_size, batch_size=batch_size, + total_steps=total) + self.assertEqual(step, expected) + + # Inexitent entry: + with self.assertRaises(ValueError): + step = utils.steps('bar', cfg, data_size=data_size, batch_size=batch_size, + total_steps=total) + step = utils.steps('bar', cfg, data_size=data_size, batch_size=batch_size, + total_steps=total, default=1234) + self.assertEqual(step, 1234) + + +class CreateLearningRateScheduleTest(parameterized.TestCase, tf.test.TestCase): + + @parameterized.named_parameters( + ('linear', 'linear', {}, 13, .5), + ('polynomial', 'polynomial', {'end': .1, 'power': 2}, 13, .325), + ('cosine', 'cosine', {}, 13, .5), + ('rsqrt', 'rsqrt', {'timescale': 1}, 13, 0.3333333), + ('stair_5', 'stair', {'steps': [10], 'mults': [.5]}, 5, 1.), + ('stair_10', 'stair', {'steps': [10], 'mults': [.5]}, 10, .5), + ('warmup_before', 'rsqrt', {'timescale': 1}, 3, .6), + ('cooldown_after', 'rsqrt', {'timescale': 1}, 20, .05), + ) + def test_schedule(self, decay_type, extra_kwargs, step, expected_lr): + lr_fn = utils.create_learning_rate_schedule( + total_steps=21, + batch_size=512, + base=.5, + decay_type=decay_type, + scale_with_batchsize=True, + warmup_steps=5, + cooldown_steps=5, + **extra_kwargs) + lr = lr_fn(step) + self.assertAlmostEqual(lr, expected_lr) + + +class CheckpointTest(tf.test.TestCase): + + def setup(self): + gacm = array_serial.GlobalAsyncCheckpointManager() + + save_path = os.path.join(self.create_tempdir('workdir'), 'checkpoint.bv') + x = utils.put_cpu(np.array([1, 2, 3, 4])) + y = utils.put_cpu(np.array([5, 6, 7, 8])) + ckpt = {'x': x, 'y': {'z': y}} + + sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend='cpu')[0] + ) + shardings = jax.tree.map(lambda _: sharding, ckpt) + + return gacm, save_path, ckpt, shardings + + def test_save_and_load(self): + gacm, save_path, ckpt, shardings = self.setup() + step = 100 + utils.save_checkpoint_ts(gacm, ckpt, save_path, step, keep=True) + gacm.wait_until_finished() + ckpt_loaded = utils.load_checkpoint_ts(save_path, + tree=ckpt, shardings=shardings) + chex.assert_trees_all_equal(ckpt_loaded, ckpt) + + save_path_step = f'{save_path}-{step:09d}' + ckpt_loaded_step = utils.tsload(save_path_step, shardings=shardings) + chex.assert_trees_all_equal(ckpt_loaded_step, ckpt) + + def test_save_and_partial_load(self): + gacm, save_path, ckpt, shardings = self.setup() + utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100) + gacm.wait_until_finished() + _ = shardings.pop('x'), ckpt.pop('x') + ckpt_loaded = utils.load_checkpoint_ts(save_path, + tree=ckpt, shardings=shardings) + chex.assert_trees_all_equal(ckpt_loaded, ckpt) + + def test_save_and_cpu_load(self): + gacm, save_path, ckpt, _ = self.setup() + utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100) + gacm.wait_until_finished() + ckpt_loaded = utils.load_checkpoint_ts(save_path) + chex.assert_trees_all_equal(ckpt_loaded, ckpt) + + def test_save_and_partial_cpu_load(self): + gacm, save_path, ckpt, _ = self.setup() + utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100) + gacm.wait_until_finished() + ckpt.pop('y') + ckpt_loaded = utils.load_checkpoint_ts(save_path, regex='x.*') + chex.assert_trees_all_equal(ckpt_loaded, ckpt) + + def test_keep_deletes(self): + def x(tree, factor): # x as in "times" for multiplying. + return jax.tree.map(lambda a: a * factor, tree) + + gacm, save_path, ckpt, _ = self.setup() + utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100, keep=False) + utils.save_checkpoint_ts(gacm, x(ckpt, 2), save_path, step=200, keep=True) + utils.save_checkpoint_ts(gacm, x(ckpt, 3), save_path, step=300, keep=False) + gacm.wait_until_finished() + ckpt_loaded_200 = utils.tsload(f'{save_path}-{200:09d}') + chex.assert_trees_all_equal(ckpt_loaded_200, x(ckpt, 2)) + ckpt_loaded_300 = utils.tsload(f'{save_path}-{300:09d}-tmp') + chex.assert_trees_all_equal(ckpt_loaded_300, x(ckpt, 3)) + ckpt_loaded_last = utils.load_checkpoint_ts(save_path) + chex.assert_trees_all_equal(ckpt_loaded_last, x(ckpt, 3)) + with self.assertRaises(Exception): # Can different types depending on fs. + _ = utils.tsload(f'{save_path}-{100:09d}') + # Test that ckpt@100 was deleted + self.assertFalse(gfile.exists(f'{save_path}-{100:09d}-tmp')) + + +if __name__ == '__main__': + tf.test.main() diff --git a/build/lib/scenic/__init__.py b/build/lib/scenic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/lib/scenic/app.py b/build/lib/scenic/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/lib/scenic/main.py b/build/lib/scenic/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ckpts/clip_vit_l14_with_masks_6c17944 b/ckpts/clip_vit_l14_with_masks_6c17944 new file mode 100644 index 0000000000000000000000000000000000000000..215495e23ca242ae84b6ea1dc4a756c3a2ac602c --- /dev/null +++ b/ckpts/clip_vit_l14_with_masks_6c17944 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26bb25bb66e747705143a35f76f5294114746e531436949d96933019cd17b2e6 +size 1048576 diff --git a/ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b b/ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b new file mode 100644 index 0000000000000000000000000000000000000000..21d2912d4b08d035e8b5d16f2852288f3a67b559 --- /dev/null +++ b/ckpts/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05_209b65b @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a05870e6eaf244a52382fb6935b955ca7e6e873181a0583d552a8330d228b900 +size 1048576 diff --git a/ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c b/ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c new file mode 100644 index 0000000000000000000000000000000000000000..8e95f93bff54a951abaec26cad050ec874249a7d --- /dev/null +++ b/ckpts/owl2-l14-1008-st-ngrams-ft-lvisbase-ens-cold-weight-04_8ca674c @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a276d61fb7c1ec7bbf68fa7d5d34a8032b6aa4f15170c1125208f77133b9c17 +size 1048576 diff --git a/images/scenic_design.jpg b/images/scenic_design.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f330bf75c9a8433a833ca288a41c0f219a9d8ffb --- /dev/null +++ b/images/scenic_design.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b341d7d1fa31b679c346834e568221da3616196040d302fd65aa7d417010ee6 +size 668643 diff --git a/images/scenic_logo.jpg b/images/scenic_logo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76fbd8bfa116986e38b96293a7d4c7132589ad88 --- /dev/null +++ b/images/scenic_logo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e74593754532c72b37f9c16b386609d794467e2baf0c5532266e69e8c0c954b +size 163943 diff --git a/images/scenic_logo.pdf b/images/scenic_logo.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c22604a5800abf0712331c3d6e85543315b7ef7b Binary files /dev/null and b/images/scenic_logo.pdf differ diff --git a/images/scenic_logo.png b/images/scenic_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..9fb1f0d03dc18fcbab60a6b01073e1997be6b81e Binary files /dev/null and b/images/scenic_logo.png differ diff --git a/owlv2_helper.py b/owlv2_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8768be029d2f4233d610d8756363885ba7fdb0 --- /dev/null +++ b/owlv2_helper.py @@ -0,0 +1,155 @@ +import os +from matplotlib import pyplot as plt +import skimage +from skimage import io as skimage_io +import numpy as np +import re +import cv2 + + +def rescale_detection_box(boxes, image): + h_img, w_img, _ = image.shape + size = max(h_img, w_img) + + pad_h = size - h_img + pad_w = size - w_img + + recovered_boxes = [] + for box in boxes: + cx, cy, w, h = box + cx = cx * size + cy = cy * size + w = w * size + h = h * size + + # if cx < 0 or cx > w_img or cy < 0 or cy > h_img: + # continue; + + x1 = cx - w / 2 + y1 = cy - h / 2 + x2 = cx + w / 2 + y2 = cy + h / 2 + recovered_boxes.append((x1, y1, x2, y2)) + return recovered_boxes + + + + +def read_images(image_dir): + images = [] + filenames = [p for p in os.listdir(image_dir) if os.path.splitext(p)[-1].lower() in [".png", ".jpg", ".jpeg",]] + filenames.sort(key=lambda p: os.path.splitext(p)[0]) + for filename in filenames: + file_path = os.path.join(image_dir, filename) + image_uint8 = skimage_io.imread(file_path) + image = image_uint8.astype(np.float32) / 255.0 + images.append(image) + return images, filenames + + + + +def preprocess_images(images, model_input_size): + processed_images = [] + for image in images: + # Pad image to square + h, w, d = image.shape + size = max(h, w) + image_padded = np.pad(image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5) + # Resize image to fit model's input size + image_resized = skimage.transform.resize( + image_padded, + (model_input_size, model_input_size), + anti_aliasing=True, + ) + processed_images.append(image_resized) + # Shape: (b, h, w, d) + return np.array(processed_images, dtype=np.float32) + + + + +def too_small(bbox, threshold=400): + x1, y1, x2, y2 = bbox + width = max(0, x2 - x1) + height = max(0, y2 - y1) + area = width * height + # Return True if area is too small + return area < threshold + + +def too_large(bbox, image, threshold=0.9): + x1, y1, x2, y2 = bbox + + bbox_width = x2 - x1 + bbox_height = y2 - y1 + bbox_area = bbox_width * bbox_height + + image_height, image_width = image.shape[:2] + image_area = image_width * image_height + + area_ratio = bbox_area / image_area + return area_ratio >= threshold + + + + +def plot_bboxes_on_orig_image(image, boxes, output_path): + plt.clf() + plt.imshow(image) + plt.axis('off') + + for box in boxes: + x1, y1, x2, y2 = box + plt.plot( + [x1, x2, x2, x1, x1], + [y1, y1, y2, y2, y1], + linewidth=0.8, alpha=0.6 + ) + + plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=300) + plt.close() + print(f" Done! Visualization saved to {output_path}") + + + + +def compute_iou(box1, box2): + x1_inter = max(box1[0], box2[0]) + y1_inter = max(box1[1], box2[1]) + x2_inter = min(box1[2], box2[2]) + y2_inter = min(box1[3], box2[3]) + inter_width = max(0, x2_inter - x1_inter) + inter_height = max(0, y2_inter - y1_inter) + inter_area = inter_width * inter_height + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union_area = box1_area + box2_area - inter_area + return inter_area / union_area if union_area > 0 else 0 + + + + +def remove_overlapping_bboxes(bboxes, iou_threshold=0.7): + if not bboxes: + return [] + bboxes = sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True) + keep = [] + for bbox in bboxes: + should_keep = True + for kept_bbox in keep: + if compute_iou(bbox, kept_bbox) > iou_threshold: + should_keep = False + break + if should_keep: + keep.append(bbox) + return keep + + + + +def get_centroid(bbox): + x1, y1, x2, y2 = bbox + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + return (int(cx), int(cy)) \ No newline at end of file diff --git a/owlv2_helper_functions.py b/owlv2_helper_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf991db19325bfe6e777daba431b4877e3daa4e --- /dev/null +++ b/owlv2_helper_functions.py @@ -0,0 +1,290 @@ +import os +from matplotlib import pyplot as plt +import skimage +from skimage import io as skimage_io +import numpy as np +import re + + +def rescale_detection_box(boxes, image): + h_img, w_img, _ = image.shape + size = max(h_img, w_img) + + pad_h = size - h_img + pad_w = size - w_img + + recovered_boxes = [] + for box in boxes: + cx, cy, w, h = box + cx = cx * size + cy = cy * size + w = w * size + h = h * size + + # if cx < 0 or cx > w_img or cy < 0 or cy > h_img: + # continue; + + x1 = cx - w / 2 + y1 = cy - h / 2 + x2 = cx + w / 2 + y2 = cy + h / 2 + recovered_boxes.append((x1, y1, x2, y2)) + return recovered_boxes + + + +def plot_boxes_on_image(image, text_queries, + scores, boxes, labels, + filename, score_threshold, + output_dir): + + colors = ['red', 'green', 'blue', 'orange', 'purple', 'pink', 'cyan', 'magenta', 'lightblue', 'darkorange', 'darkgreen', 'darkred', 'lavender', 'brown', 'gray', 'black'] + + # 显示原始图片 + plt.clf() + plt.imshow(image) + plt.axis('off') + + # 绘制边界框 + for score, box, label in zip(scores, boxes, labels): + if score < score_threshold: + continue; + + x1, y1, x2, y2 = box + # print(f"box coord: {[x1, y1, x2, y2]}") + plt.plot( + [x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], + color=colors[label], linewidth=0.6, alpha=0.6 + ) + plt.text( + x1, y2 + 0.015, + f'{text_queries[label]}: {score:1.2f}', + ha='left', va='top', color=colors[label], fontsize=6, + bbox={'facecolor': 'white', 'edgecolor': colors[label], 'boxstyle': 'square,pad=.3', 'alpha': 0.5} + ) + # 保存图片到指定路径 OUTPUT_DIR + output_path = os.path.join(output_dir, filename) + plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=300) + print(f"Image with boxes saved to {output_path}") + + + +def image_based_plot_boxes_on_image(image, text_queries, scores, boxes, filename,output_dir): + colors = ['red', 'green', 'blue', 'orange', 'cyan', 'magenta', 'lightblue', 'darkorange', 'lavender'] + + plt.clf() + plt.imshow(image) + plt.axis('off') + for score, box, text_query, color in zip(scores, boxes, text_queries, colors): + x1, y1, x2, y2 = box + plt.plot( + [x1, x2, x2, x1, x1], + [y1, y1, y2, y2, y1], + color=color, linewidth=1 + ) + plt.text( + x1, y2 + 0.015, + f'{text_query}: {score:1.2f}', + ha='left', va='top', color=color, fontsize=6, + bbox={'facecolor': 'white', 'edgecolor': color, 'boxstyle': 'square,pad=.3','alpha': 0.5} + ) + output_path = os.path.join(output_dir, filename) + plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=300) + print(f"Image with boxes saved to {output_path}") + + + +def get_iou(bbox1, bbox2): + # 分别解包 bbox1 和 bbox2 的坐标 + x1_min, y1_min, x1_max, y1_max = bbox1 + x2_min, y2_min, x2_max, y2_max = bbox2 + + # 计算交集的顶点坐标 + inter_x_min = max(x1_min, x2_min) + inter_y_min = max(y1_min, y2_min) + inter_x_max = min(x1_max, x2_max) + inter_y_max = min(y1_max, y2_max) + + # 计算交集的宽度和高度(确保为非负值) + inter_width = max(0, inter_x_max - inter_x_min) + inter_height = max(0, inter_y_max - inter_y_min) + inter_area = inter_width * inter_height + + # 计算每个边界框的面积 + area1 = (x1_max - x1_min) * (y1_max - y1_min) + area2 = (x2_max - x2_min) * (y2_max - y2_min) + + # 计算并集面积 + union_area = area1 + area2 - inter_area + + # 计算 IOU + iou = inter_area / union_area if union_area > 0 else 0 + return iou + + + + +def read_images(image_dir): + images = [] + filenames = sorted(os.listdir(image_dir)) + for filename in filenames: + file_path = os.path.join(image_dir, filename) + image_uint8 = skimage_io.imread(file_path) + image = image_uint8.astype(np.float32) / 255.0 + images.append(image) + return images, filenames + + + +def preprocess_images(images, model_input_size): + processed_images = [] + for image in images: + # Pad image to square + h, w, d = image.shape + size = max(h, w) + image_padded = np.pad(image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5,) + # Resize image to fit model's input size + image_resized = skimage.transform.resize( + image_padded, + (model_input_size, model_input_size), + anti_aliasing=True, + ) + processed_images.append(image_resized) + # Shape: (b, h, w, d) + return np.array(processed_images, dtype=np.float32) + + + +def prepare_images(image_dir, model_input_size): + filenames = sorted(os.listdir(image_dir)) + + images = [] + for filename in filenames: + file_path = os.path.join(image_dir, filename) + image_uint8 = skimage_io.imread(file_path) + image = image_uint8.astype(np.float32) / 255.0 + + # Pad image to square + h, w, d = image.shape + size = max(h, w) + image_padded = np.pad( + image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5 + ) + + # Resize image to fit model's input size + image_resized = skimage.transform.resize( + image_padded, + (model_input_size, model_input_size), + anti_aliasing=True, + ) + images.append(image_resized) + + # Shape: (b, h, w, d) + return np.array(images, dtype=np.float32), filenames + + + +def plot_bbox_on_image(image, boxes, objectnesses, threshold, output_file): + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + ax.imshow(image, extent=(0, 1, 1, 0)) + ax.set_axis_off() + + for i, (box, objectness) in enumerate(zip(boxes, objectnesses)): + if objectness < threshold: + continue + + index = i + cx, cy, w, h = box + ax.plot( + [cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2], + [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2], + color='lime', + ) + + ax.text( + cx - w / 2 + 0.015, + cy + h / 2 - 0.015, + f'Index {i}: {objectness:1.2f}', + ha='left', + va='bottom', + color='black', + bbox={ + 'facecolor': 'white', + 'edgecolor': 'lime', + 'boxstyle': 'square,pad=.3', + }, + ) + + ax.set_xlim(0, 1) + ax.set_ylim(1, 0) + ax.set_title(f'Top objects by objectness') + + # 保存图片到指定路径 + plt.savefig(output_file, bbox_inches='tight', dpi=300) + plt.close() # 关闭图像以释放内存 + print(f"结果图片已保存到: {output_file}") + return index + + +def top_object_index(objectnesses, threshold): + for i, objectness in enumerate(objectnesses): + if objectness < threshold: + continue + else: + return i + + + + +def boxes_filter(pred_bboxes, raw_bboxes, pred_scores, instances): + # Step 1: Filter by pred_scores + filtered_indices = [i for i, score in enumerate(pred_scores) if score >= 0.97] + + pred_bboxes = [pred_bboxes[i] for i in filtered_indices] + raw_bboxes = [raw_bboxes[i] for i in filtered_indices] + pred_scores = [pred_scores[i] for i in filtered_indices] + instances = [instances[i] for i in filtered_indices] + + # Step 2: Filter by IoU + keep_indices = set(range(len(pred_bboxes))) + for i in range(len(pred_bboxes)): + if i not in keep_indices: + continue + for j in range(i + 1, len(pred_bboxes)): + if j not in keep_indices: + continue + iou = get_iou(pred_bboxes[i], pred_bboxes[j]) + if iou > 0.9: + if pred_scores[i] >= pred_scores[j]: + keep_indices.discard(j) + else: + keep_indices.discard(i) + + pred_bboxes = [pred_bboxes[i] for i in sorted(keep_indices)] + raw_bboxes = [raw_bboxes[i] for i in sorted(keep_indices)] + pred_scores = [pred_scores[i] for i in sorted(keep_indices)] + instances = [instances[i] for i in sorted(keep_indices)] + + # Step 3: Filter by duplicate instances + instance_map = {} + for i in range(len(instances)): + instance = instances[i] + if instance not in instance_map or pred_scores[i] > pred_scores[instance_map[instance]]: + instance_map[instance] = i + + unique_indices = sorted(instance_map.values()) + pred_bboxes = [pred_bboxes[i] for i in unique_indices] + raw_bboxes = [raw_bboxes[i] for i in unique_indices] + pred_scores = [pred_scores[i] for i in unique_indices] + instances = [instances[i] for i in unique_indices] + + return pred_bboxes, raw_bboxes, pred_scores, instances + + + +def format_string(input_string: str) -> str: + # 大写 转 小写 + lowercased = input_string.lower() + # 空格 转 下划线 + transformed = re.sub(r"\s+", "_", lowercased) # \s+ 匹配一个或多个空白字符 + return transformed \ No newline at end of file diff --git a/owlv2_img_embeding.py b/owlv2_img_embeding.py new file mode 100644 index 0000000000000000000000000000000000000000..170312b227e391d4069e3d8d008dc3c81ad8cd5d --- /dev/null +++ b/owlv2_img_embeding.py @@ -0,0 +1,180 @@ +import os +import sys +import json + +# pip install ott-jax==0.2.0 +import jax +import numpy as np +import tensorflow as tf +from scipy.special import expit as sigmoid + +import skimage +from skimage import io as skimage_io +from skimage import transform as skimage_transform +import matplotlib as mpl +from matplotlib import pyplot as plt + +sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision') +tf.config.experimental.set_visible_devices([], 'GPU') + +from scenic.projects.owl_vit import configs +from scenic.projects.owl_vit import models + +# from owlv2_helper_functions import prepare_images +from owlv2_helper_functions import read_images, preprocess_images +from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image +from owlv2_helper_functions import top_object_index +from owlv2_helper_functions import rescale_detection_box +from owlv2_helper_functions import get_iou, boxes_filter + + + +""" +Prepare OWLv2 pretrained model +""" +config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint') +module = models.TextZeroShotDetectionModule( + body_configs=config.model.body, + objectness_head_configs=config.model.objectness_head, + normalize=config.model.normalize, + box_bias=config.model.box_bias) +variables = module.load_variables(config.init_from.checkpoint_path) + + + + +""" +Wrapped model components +""" +import functools + +image_embedder = jax.jit( + functools.partial(module.apply, variables, train=False, method=module.image_embedder)) + +objectness_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.objectness_predictor)) + +box_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.box_predictor)) + +class_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.class_predictor)) + + + + +""" +Detect the main object on instances' images +""" +INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances' +INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections' + +model_input_size = config.dataset_configs.input_size +images, source_images_names = read_images(INSTANCE_DIR) +source_images = preprocess_images(images, model_input_size) + + +instances, query_embeddings, indexes = [], [], [] +for source_image, source_image_name in zip(source_images, source_images_names): + feature_map = image_embedder(source_image[None, ...]) + b, h, w, d = feature_map.shape + image_features = feature_map.reshape(b, h * w, d) + + objectnesses = objectness_predictor(image_features)['objectness_logits'] + bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes'] + all_class_embeddings = class_predictor(image_features=image_features)['class_embeddings'] + + # Remove batch dimension + objectnesses = np.array(objectnesses[0]) + bboxes = np.array(bboxes[0]) + all_class_embeddings = np.array(all_class_embeddings[0]) + + top_k = 1 + objectnesses = sigmoid(objectnesses) + objectness_threshold = np.partition(objectnesses, -top_k)[-top_k] + + index = top_object_index(objectnesses, objectness_threshold) + query_embedding = all_class_embeddings[index] + + indexes.append(index) + instances.append(source_image_name.split('_')[0]) + query_embeddings.append(query_embedding) + + # Plot instance detection + output_file = os.path.join(INSTANCE_DETECTION, source_image_name) + plot_bbox_on_image(source_image, bboxes, objectnesses, objectness_threshold, output_file) + + + + +# IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample' +# OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results' +IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08' +OUTPUT_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08_detect' + +digital_twins = {} + +images, target_images_names = read_images(IMAGE_DIR) +target_images = preprocess_images(images, model_input_size) +h, w, d = images[0].shape +size = max(h, w) + +for target_image, target_image_name, image in zip(target_images, target_images_names, images): + + feature_map = image_embedder(target_image[None, ...]) + b, h, w, d = feature_map.shape + image_features = feature_map.reshape(b, h * w, d) + all_bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes'] + + pred_scores, pred_bboxes = [], [] + for i, query_embedding in enumerate(query_embeddings): + target_class_predictions = class_predictor( + image_features=feature_map.reshape(b, h * w, d), + query_embeddings=query_embedding[None, None, ...], # [batch, queries, d] + ) + + # Remove batch dimension and convert to numpy: + logits = np.array(target_class_predictions['pred_logits'][0]) + bboxes = np.array(all_bboxes[0]) + + top_ind = np.argmax(logits[:, 0], axis=0) + score = logits[top_ind, 0] + bbox = bboxes[top_ind] + + pred_bboxes.append(bbox) + pred_scores.append(score) + + instances_dup = instances[:] + pred_scores = sigmoid(pred_scores) + rescaled_bboxes = rescale_detection_box(pred_bboxes, image) + + rescaled_bboxes, pred_bboxes, pred_scores, instances_dup = boxes_filter(rescaled_bboxes, pred_bboxes, pred_scores, instances_dup) + + + count = {} + for instance_name in instances_dup: + count[instance_name] = 0 + + digital_twins[target_image_name] = {} + for instance_i, (instance_name, instance_box, instance_raw_box, instance_score) in enumerate(zip(instances_dup, rescaled_bboxes, pred_bboxes, pred_scores)): + x1, y1, x2, y2 = map(float, instance_box) + cx, cy, box_w, box_h = map(float, instance_raw_box) + x = round(instance_score, 2) + + digital_twins[target_image_name][f'{instance_name}_{count[instance_name]}'] = { + 'detection_label': instance_name, + 'detection_box': [x1, y1, x2, y2], + 'detection_centroid': [cx*size, cy*size], + 'detection_score': round(float(instance_score), 2), + } + count[instance_name] += 1 + + + image_based_plot_boxes_on_image(image, instances_dup, pred_scores, rescaled_bboxes, target_image_name, OUTPUT_DIR) + + +JSON_OUT_PATH = "/home/netzone22/bohanliu_2025/DT_SPR_video08_detection.json" +if not os.path.exists(JSON_OUT_PATH): + os.makedirs(JSON_OUT_PATH) +with open(JSON_OUT_PATH, "w", encoding="utf-8") as json_f: + json.dump(digital_twins, json_f, indent=4) \ No newline at end of file diff --git a/owlv2_img_embeding_2.py b/owlv2_img_embeding_2.py new file mode 100644 index 0000000000000000000000000000000000000000..96fb0da8cf201cdef78cb47ae0bdaff1724f0642 --- /dev/null +++ b/owlv2_img_embeding_2.py @@ -0,0 +1,148 @@ +import os +import sys +import json + +# pip install ott-jax==0.2.0 +import jax +import numpy as np +import tensorflow as tf +from scipy.special import expit as sigmoid + +import skimage +from skimage import io as skimage_io +from skimage import transform as skimage_transform +import matplotlib as mpl +from matplotlib import pyplot as plt + +sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision') +tf.config.experimental.set_visible_devices([], 'GPU') + +from scenic.projects.owl_vit import configs +from scenic.projects.owl_vit import models + +# from owlv2_helper_functions import prepare_images +from owlv2_helper_functions import read_images, preprocess_images +from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image, plot_boxes_on_image +from owlv2_helper_functions import top_object_index +from owlv2_helper_functions import rescale_detection_box + + + + +""" +Prepare OWLv2 pretrained model +""" +config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint') +module = models.TextZeroShotDetectionModule( + body_configs=config.model.body, + objectness_head_configs=config.model.objectness_head, + normalize=config.model.normalize, + box_bias=config.model.box_bias) +variables = module.load_variables(config.init_from.checkpoint_path) + + + + +""" +Wrapped model components +""" +import functools + +image_embedder = jax.jit( + functools.partial(module.apply, variables, train=False, method=module.image_embedder)) +objectness_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.objectness_predictor)) +box_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.box_predictor)) +class_predictor = jax.jit( + functools.partial(module.apply, variables, method=module.class_predictor)) + + + + +""" +Detect the main object on instances' images +""" +INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_0' +INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections_0' + +model_input_size = config.dataset_configs.input_size +images, source_images_names = read_images(INSTANCE_DIR) +source_images = preprocess_images(images, model_input_size) + +feature_map = image_embedder(source_images) +b, h, w, d = feature_map.shape +image_features = feature_map.reshape(b, h * w, d) + +objectnesses = objectness_predictor(image_features)['objectness_logits'] +bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes'] +source_class_embeddings = class_predictor(image_features=image_features)['class_embeddings'] + +# print(f"Debug: source instance detection") +# print(f" Source images features shape: {image_features.shape}") +# print(f" objectnesses shape: {objectnesses.shape}") +# print(f" bboxes shape: {bboxes.shape}") +# print(f" source_class_embeddings shape: {source_class_embeddings.shape}") + +objectnesses = sigmoid(objectnesses) +top_objectnesses = np.max(objectnesses, axis=1) + +instances, query_embeddings, indexes = [], [], [] +for i in range(len(source_images_names)): + index = top_object_index(objectnesses[i], top_objectnesses[i]) + query_embedding = source_class_embeddings[index] + + indexes.append(index) + instances.append(source_images_names[i].split('_')[0]) + query_embeddings.append(query_embedding) + + output_file = os.path.join(INSTANCE_DETECTION, source_images_names[i]) + plot_bbox_on_image(source_images[i], bboxes[i], objectnesses[i], top_objectnesses[i], output_file) + + + + + +IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample' +OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results' + +images, target_images_names = read_images(IMAGE_DIR) +target_images = preprocess_images(images, model_input_size) + +for target_image, target_image_name, image in zip(target_images, target_images_names, images): + + feature_map = image_embedder(target_image[None, ...]) + b, h, w, d = feature_map.shape + target_boxes = box_predictor(image_features=feature_map.reshape(b, h * w, d), feature_map=feature_map)['pred_boxes'] + + target_class_predictions = class_predictor( + image_features=feature_map.reshape(b, h * w, d), + query_embeddings=query_embedding[None, ...], # [batch, queries, d] + ) + + logits = np.array(target_class_predictions['pred_logits'][0]) + raw_boxes = np.array(target_boxes[0]) + + top_ind = np.argmax(logits[:, 0], axis=0) + score = sigmoid(logits[top_ind, 0]) + + # labels = np.argmax(target_class_predictions['pred_logits'][0], axis=-1) + # scores = sigmoid(np.max(logits, axis=-1)) + + boxes = rescale_detection_box(raw_boxes, image) + boxes = boxes[top_ind] + + score = np.array([score]) + boxes = np.array([boxes]) + + image_based_plot_boxes_on_image(image, instances, score, boxes, target_image_name, OUTPUT_DIR) + + print(f"Debug: traget instance detection") + # print(f" target_class_predictions' keys: {target_class_predictions.keys()}") + print(f" target_logits: {logits.shape}") + print(logits) + # print(f" target_scores: {scores.shape}") + # print(f" target_labels: {labels.shape}") + # print(f" target_boxes shape: {raw_boxes.shape}") + + # plot_boxes_on_image(image, instances, scores, boxes, labels, target_image_name, 0.5, OUTPUT_DIR) diff --git a/owlv2_inference.py b/owlv2_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ed9fa343c89d5c8b7c1dbd95b2a6dd106dad28 --- /dev/null +++ b/owlv2_inference.py @@ -0,0 +1,176 @@ +import os +import sys +import json + +# pip install ott-jax==0.2.0 +import jax +import numpy as np +import tensorflow as tf +from scipy.special import expit as sigmoid + +import skimage +from skimage import io as skimage_io +from skimage import transform as skimage_transform +import matplotlib as mpl +from matplotlib import pyplot as plt + +sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision') +tf.config.experimental.set_visible_devices([], 'GPU') + +from scenic.projects.owl_vit import configs +from scenic.projects.owl_vit import models + +from owlv2_helper_functions import plot_boxes_on_image, rescale_detection_box +from owlv2_helper_functions import format_string + + + +# IMAGE_DIR = '/home/netzone22/bohanliu_2025/HALO/case28/case_28' +# OUTPUT_DIR = '/home/netzone22/bohanliu_2025/HALO/case28/case_28_detection' +# JOSN_IN = '/home/netzone22/bohanliu_2025/structured_prompt.json' +# JSON_OUT = '/home/netzone22/bohanliu_2025/HALO/case28/case_28_detection.json' + + +# IMAGE_DIR = '/home/netzone22/bohanliu_2025/HALO_test/single_image' +# OUTPUT_DIR = '/home/netzone22/bohanliu_2025/HALO_test/single_image/detection' +# JOSN_IN = '/home/netzone22/bohanliu_2025/structured_prompt.json' +# JSON_OUT = '/home/netzone22/bohanliu_2025/HALO_test/single_image/detection.json' + +INSTANCE = 'machine' +IMAGE_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic/{INSTANCE}' +OUTPUT_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/detection' +JOSN_IN = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/structured_prompt.json' +JSON_OUT = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/metadata.json' + + +THRESHOLD = 0.12 + +TEXT = [] +with open(JOSN_IN, 'r') as file: + prompts = json.load(file) + +for target_obj in prompts['target_obj']: + TEXT.append(format_string(target_obj)) +if prompts['spacial_info'] == True: + for referred_obj in prompts['referred_obj'].keys(): + TEXT.append(format_string(referred_obj)) +print(f"\nQueries: {TEXT}\n") + + + +### Choose config +# config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint') +config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint') + + +### Load the model and variables +module = models.TextZeroShotDetectionModule( + body_configs=config.model.body, + objectness_head_configs=config.model.objectness_head, + normalize=config.model.normalize, + box_bias=config.model.box_bias) + +variables = module.load_variables(config.init_from.checkpoint_path) + + + +### Prepare text queries +text_queries = TEXT # ['machine', 'human'] +tokenized_queries = np.array([ + module.tokenize(q, config.dataset_configs.max_query_length) + for q in text_queries +]) +# Pad tokenized queries to avoid recompilation if number of queries changes: +tokenized_queries = np.pad( + tokenized_queries, + pad_width=((0, 100 - len(text_queries)), (0, 0)), + constant_values=0) + + + +### Prepare image +jitted = jax.jit(module.apply, static_argnames=('train',)) +digital_twins = {} +# filenames = sorted(tf.io.gfile.listdir(IMAGE_DIR)) + +extensions = {".jpg", ".jpeg", ".png"} +filenames = sorted([ + file + for file in tf.io.gfile.listdir(IMAGE_DIR) + if any(file.lower().endswith(ext) for ext in extensions) +]) + +for i, filename in enumerate(filenames): + file_path = os.path.join(IMAGE_DIR, filename) + image_uint8 = skimage_io.imread(file_path) + image = image_uint8.astype(np.float32) / 255.0 + # Pad to square with gray pixels on bottom and right: + h, w, _ = image.shape + # print(f"original img: {h} x {w}") + size = max(h, w) + image_padded = np.pad(image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5) + # Resize to model input size: + input_image = skimage.transform.resize( + image_padded, + (config.dataset_configs.input_size, config.dataset_configs.input_size), + anti_aliasing=True + ) + + + ### Get predictions + # This will take a minute on the first execution due to model compilation. + # Subsequent executions will be faster. + # jitted = jax.jit(module.apply, static_argnames=('train',)) + + # Note: The model expects a batch dimension. + predictions = jitted( + variables, + input_image[None, ...], + tokenized_queries[None, ...], + train=False) + # Remove batch dimension and convert to numpy: + predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions) + # print(predictions.keys()) + + ### Plot prediction + score_threshold = THRESHOLD # 0.1 + + logits = predictions['pred_logits'][..., :len(text_queries)] # Remove padding. + scores = sigmoid(np.max(logits, axis=-1)) + labels = np.argmax(predictions['pred_logits'], axis=-1) + raw_boxes = predictions['pred_boxes'] + boxes = rescale_detection_box(raw_boxes, image) + + + + ### Write results into JSON file. + digital_twins[filename] = {} + + count = {} + for label in labels: + count[text_queries[label]] = 0 + + for score, raw_box, box, label in zip(scores, raw_boxes, boxes, labels): + if score < score_threshold: + continue; + + # x1, y1, x2, y2 = box + x1, y1, x2, y2 = map(float, box) + cx, cy, box_w, box_h = map(float, raw_box) + x = round(score, 2) + + digital_twins[filename][f'{text_queries[label]}_{count[text_queries[label]]}'] = { + 'detection_label': text_queries[label], + 'detection_box': [x1, y1, x2, y2], + 'detection_centroid': [cx*size, cy*size], + 'detection_score': round(float(score), 2), + } + count[text_queries[label]]+=1 + + if not os.path.exists(OUTPUT_DIR): + os.makedirs(OUTPUT_DIR) + plot_boxes_on_image(image, text_queries, scores, boxes, labels, filename, score_threshold, OUTPUT_DIR) + + +with open(JSON_OUT, "w", encoding="utf-8") as json_f: + json.dump(digital_twins, json_f, indent=4) diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..98037fc9c3181c961927ee0034afa8ef616e32e6 --- /dev/null +++ b/pylintrc @@ -0,0 +1,372 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable=use-symbolic-message-instead,useless-supression,fixme + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" + +disable= + attribute-defined-outside-init, + duplicate-code, + # invalid-name, + # missing-docstring, + protected-access, + too-few-public-methods, + # handled by black + format + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=2000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[BASIC] + +# List of builtins function names that should not be used, separated by a comma +bad-functions=map,filter,input + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=foo,bar,baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=__.*__ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# List of decorators that define properties, such as abc.abstractproperty. +property-classes=abc.abstractproperty + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis +ignored-modules= + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=10 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=25 + +# Maximum number of return / yield for function / method body +max-returns=11 + +# Maximum number of branch for function / method body +max-branches=26 + +# Maximum number of statements in function / method body +max-statements=100 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=11 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=25 + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp,__post_init__ + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/scenic.egg-info/PKG-INFO b/scenic.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..e33311b418782a60e7e5517671cab875d3b4c3c2 --- /dev/null +++ b/scenic.egg-info/PKG-INFO @@ -0,0 +1,254 @@ +Metadata-Version: 2.1 +Name: scenic +Version: 0.0.1 +Summary: A Jax Library for Computer Vision Research and Beyond. +Home-page: http://github.com/google-research/scenic +Author: Scenic Authors +Author-email: no-reply@google.com +License: Apache 2.0 +Keywords: Scenic +Classifier: Development Status :: 1 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: absl-py>=1.0.0 +Requires-Dist: numpy>=1.12 +Requires-Dist: jax>=0.4.3 +Requires-Dist: jaxlib>=0.4.3 +Requires-Dist: flax>=0.4.0 +Requires-Dist: ml-collections>=0.1.1 +Requires-Dist: tensorflow>=2.7 +Requires-Dist: immutabledict>=2.2.1 +Requires-Dist: clu>=0.0.6 +Requires-Dist: tensorflow-datasets +Requires-Dist: optax@ git+https://github.com/google-deepmind/optax.git@main +Provides-Extra: testing +Requires-Dist: pytest; extra == "testing" +Requires-Dist: shapely; extra == "testing" +Requires-Dist: ott-jax>=0.2.0; extra == "testing" +Requires-Dist: sklearn; extra == "testing" +Requires-Dist: lingvo==0.12.6; extra == "testing" +Requires-Dist: seaborn>=0.11.2; extra == "testing" +Requires-Dist: dmvr@ git+https://github.com/google-deepmind/dmvr.git ; extra == "testing" + +# Scenic +
+scenic logo +
+ +*Scenic* is a codebase with a focus on research around attention-based models +for computer vision. Scenic has been successfully used to develop +classification, segmentation, and detection models for multiple modalities +including images, video, audio, and multimodal combinations of them. + +More precisely, *Scenic* is a (i) set of shared light-weight libraries solving +tasks commonly encountered tasks when training large-scale (i.e. multi-device, +multi-host) vision models; and (ii) several *projects* containing fully +fleshed out problem-specific training and evaluation loops using these +libraries. + +Scenic is developed in [JAX](https://github.com/jax-ml/jax) and uses +[Flax](https://github.com/google/flax). + +### Contents +* [What we offer](#what-we-offer) +* [SOTA models and baselines in Scenic](#sota-models-and-baselines-in-scenic) +* [Philosophy](#philosophy) +* [Getting started](#getting-started) +* [Scenic component design](#scenic-component-design) +* [Citing Scenic](#citing-scenic) + +## What we offer +Among others *Scenic* provides + +* Boilerplate code for launching experiments, summary writing, logging, + profiling, etc; +* Optimized training and evaluation loops, losses, metrics, bi-partite matchers, + etc; +* Input-pipelines for popular vision datasets; +* [Baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models), +including strong non-attentional baselines. + + +## SOTA models and baselines in *Scenic* +There are some SOTA models and baselines in Scenic which were either developed +using Scenic, or have been reimplemented in Scenic: + +Projects that were developed in Scenic or used it for their experiments: + +* [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) +* [OmniNet: Omnidirectional Representations from Transformers](https://arxiv.org/abs/2103.01075) +* [Attention Bottlenecks for Multimodal Fusion](https://arxiv.org/abs/2107.00135) +* [TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?](https://arxiv.org/abs/2106.11297) +* [Exploring the Limits of Large Scale Pre-training](https://arxiv.org/abs/2110.02095) +* [The Efficiency Misnomer](https://arxiv.org/abs/2110.12894) +* [Discrete Representations Strengthen Vision Transformer Robustness](https://arxiv.org/abs/2111.10493) +* [Pyramid Adversarial Training Improves ViT Performance](https://arxiv.org/abs/2111.15121) +* [VUT: Versatile UI Transformer for Multi-Modal Multi-Task User Interface Modeling](https://arxiv.org/abs/2112.05692) +* [CLAY: Learning to Denoise Raw Mobile UI Layouts for Improving Datasets at Scale](https://arxiv.org/abs/2201.04100) +* [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455) +* [Multiview Transformers for Video Recognition](https://arxiv.org/abs/2201.04288) +* [PolyViT: Co-training Vision Transformers on Images, Videos and Audio](https://arxiv.org/abs/2111.12993) +* [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) +* [Learning with Neighbor Consistency for Noisy Labels](https://arxiv.org/abs/2202.02200) +* [Token Turing Machines](https://arxiv.org/pdf/2211.09119.pdf) +* [Vid2Seq: Large-Scale Pretraining of a Visual Language Model for Dense Video Captioning](https://arxiv.org/pdf/2302.14115.pdf) +* [AVATAR: Unconstrained Audiovisual Speech Recognition](https://arxiv.org/abs/2206.07684) +* [Adaptive Computation with Elastic Input Sequence](https://arxiv.org/abs/2301.13195) +* [Location-Aware Self-Supervised Transformers for Semantic Segmentation](https://arxiv.org/abs/2212.02400) +* [How can objects help action recognition?](https://openaccess.thecvf.com/content/CVPR2023/html/Zhou_How_Can_Objects_Help_Action_Recognition_CVPR_2023_paper.html) +* [Verbs in Action: Improving verb understanding in video-language models](https://arxiv.org/abs/2304.06708) +* [Unified Visual Relationship Detection with Vision and Language Models](https://arxiv.org/abs/2303.08998) +* [UnLoc: A Unified Framework for Video Localization Tasks](https://arxiv.org/abs/2308.11062) +* [REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory](https://arxiv.org/abs/2212.05221) +* [Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922) +* [MatFormer: Nested Transformer for Elastic Inference](https://arxiv.org/abs/2310.07707) +* [Pixel Aligned Language Models](https://arxiv.org/abs/2312.09237) +* [A Generative Approach for Wikipedia-Scale Visual Entity Recognition](https://arxiv.org/abs/2403.02041) +* [Streaming Dense Video Captioning](https://arxiv.org/abs/2404.01297) +* [Dense Video Object Captioning from Disjoint Supervision](https://arxiv.org/abs/2306.11729) + +More information can be found in [projects](https://github.com/google-research/scenic/tree/main/scenic/projects#list-of-projects-hosted-in-scenic). + +Baselines that were reproduced in Scenic: + +* [(ViT) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) +* [(DETR) End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) +* [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159) +* [(CLIP) Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) +* [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601) +* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) +* [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) +* [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370) +* [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) +* [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) +* [PCT: Point Cloud Transformer](https://arxiv.org/abs/2012.09688) +* [Universal Transformers](https://arxiv.org/abs/1807.03819) +* [PonderNet](https://arxiv.org/abs/2107.05407) +* [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) +* [Rethinking Attention with Performers](https://arxiv.org/abs/2009.14794) +* [(CenterNet) Objects as Points](https://arxiv.org/abs/1904.07850) +* [(SAM) Segment Anything](https://arxiv.org/abs/2304.02643) + + +More information can be found in [baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models). + +
+## Philosophy +*Scenic* aims to facilitate rapid prototyping of large-scale vision models. To +keep the code simple to understand and extend we prefer *forking and +copy-pasting over adding complexity or increasing abstraction*. Only when +functionality proves to be widely useful across many models and tasks it may be +upstreamed to Scenic's shared libraries. + + + +## Getting started +* See `projects/baselines/README.md` for a walk-through baseline models and + instructions on how to run the code. +* If you would like to contribute to *Scenic*, please check out the + [Philisophy](#philosophy), [Code structure](#code_structure) and + [Contributing](CONTRIBUTING.md) sections. + Should your contribution be a part of the shared libraries, please send us a + pull request! + + +### Quickstart +You will need Python 3.9 or later. Download the code from GitHub + +```shell +$ git clone https://github.com/google-research/scenic.git +$ cd scenic +$ pip install . +``` + +and run training for ViT on ImageNet: + +```shell +$ python scenic/main.py -- \ + --config=scenic/projects/baselines/configs/imagenet/imagenet_vit_config.py \ + --workdir=./ +``` + +Note that for specific projects and baselines, you might need to install extra +packages that are mentioned in their `README.md` or `requirements.txt` files. + +[Here](https://colab.research.google.com/github/google-research/scenic/blob/main/scenic/common_lib/colabs/scenic_playground.ipynb) +is also a minimal colab to train a simple feed-forward model using Scenic. + + +## Scenic component design +Scenic is designed to propose different levels of abstraction, to support +hosting projects that only require changing hyper-parameters by defining config +files, to those that need customization on the input pipeline, model +architecture, losses and metrics, and the training loop. To make this happen, +the code in Scenic is organized as either _project-level_ code, +which refers to customized code for specific projects or baselines or +_library-level_ code, which refers to common functionalities and general +patterns that are adapted by the majority of projects. The project-level +code lives in the `projects` directory. + +
+scenic design +
+ +### Library-level code +The goal is to keep the library-level code minimal and well-tested and to avoid +introducing extra abstractions to support minor use-cases. Shared libraries +provided by *Scenic* are split into: + +* `dataset_lib`: Implements IO pipelines for loading and pre-processing data + for common Computer Vision tasks and benchmarks (see "Tasks and Datasets" + section). All pipelines are designed to be scalable and support multi-host + and multi-device setups, taking care dividing data among multiple hosts, + incomplete batches, caching, pre-fetching, etc. +* `model_lib` : Provides + * several abstract model interfaces (e.g. `ClassificationModel` or + `SegmentationModel` in `model_lib.base_models`) with task-specific + losses and metrics; + * neural network layers in `model_lib.layers`, focusing on efficient + implementation of attention and transformer layers; + * accelerator-friendly implementations of bipartite matching + algorithms in `model_lib.matchers`. +* `train_lib`: Provides tools for constructing training loops and implements + several optimized trainers (classification trainer and segmentation trainer) + that can be forked for customization. +* `common_lib`: General utilities, like logging and debugging modules, + functionalities for processing raw data, etc. + +### Project-level code +Scenic supports the development of customized solutions for customized tasks and +data via the concept of "project". There is no one-fits-all recipe for how much +code should be re-used by a project. Projects can consist of only configs and +use the common models, trainers, task/data that live in library-level code, or +they can simply fork any of the mentioned functionalities and redefine, layers, +losses, metrics, logging methods, tasks, architectures, as well as training and +evaluation loops. The modularity of library-level code makes it flexible for +projects to fall placed on any spot in the "run-as-is" to "fully customized" +spectrum. + +Common baselines such as a ResNet and Vision Transformer (ViT) are implemented +in the [`projects/baselines`](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines) +project. Forking models in this directory is a good starting point for new +projects. + + +## Citing Scenic +If you use Scenic, you can cite our [white paper](https://openaccess.thecvf.com/content/CVPR2022/html/Dehghani_Scenic_A_JAX_Library_for_Computer_Vision_Research_and_Beyond_CVPR_2022_paper.html). +Here is an example BibTeX entry: + +```bibtex +@InProceedings{dehghani2021scenic, + author = {Dehghani, Mostafa and Gritsenko, Alexey and Arnab, Anurag and Minderer, Matthias and Tay, Yi}, + title = {Scenic: A JAX Library for Computer Vision Research and Beyond}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2022}, + pages = {21393-21398} +} +``` + +_Disclaimer: This is not an official Google product._ diff --git a/scenic.egg-info/SOURCES.txt b/scenic.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..7c4d9f943c65b324da6ee008bb806439d22788f6 --- /dev/null +++ b/scenic.egg-info/SOURCES.txt @@ -0,0 +1,700 @@ +LICENSE +README.md +setup.py +scenic/__init__.py +scenic/app.py +scenic/main.py +scenic.egg-info/PKG-INFO +scenic.egg-info/SOURCES.txt +scenic.egg-info/dependency_links.txt +scenic.egg-info/requires.txt +scenic.egg-info/top_level.txt +scenic/common_lib/__init__.py +scenic/common_lib/common_utils.py +scenic/common_lib/debug_utils.py +scenic/common_lib/export_utils.py +scenic/common_lib/image_utils.py +scenic/common_lib/video_utils.py +scenic/common_lib/colabs/__init__.py +scenic/common_lib/tests/__init__.py +scenic/common_lib/tests/test_common_utils.py +scenic/common_lib/tests/test_debug_utils.py +scenic/common_lib/tests/test_image_utils.py +scenic/common_lib/tests/test_video_utils.py +scenic/dataset_lib/__init__.py +scenic/dataset_lib/bair_dataset.py +scenic/dataset_lib/cifar10_dataset.py +scenic/dataset_lib/cityscapes_dataset.py +scenic/dataset_lib/dataset_utils.py +scenic/dataset_lib/datasets.py +scenic/dataset_lib/fashion_mnist_dataset.py +scenic/dataset_lib/imagenet_dataset.py +scenic/dataset_lib/mnist_dataset.py +scenic/dataset_lib/oxford_pets_dataset.py +scenic/dataset_lib/svhn_dataset.py +scenic/dataset_lib/video_ops.py +scenic/dataset_lib/big_transfer/__init__.py +scenic/dataset_lib/big_transfer/bit.py +scenic/dataset_lib/big_transfer/builder.py +scenic/dataset_lib/big_transfer/registry.py +scenic/dataset_lib/big_transfer/preprocessing/__init__.py +scenic/dataset_lib/big_transfer/preprocessing/autoaugment.py +scenic/dataset_lib/big_transfer/preprocessing/ops.py +scenic/dataset_lib/big_transfer/preprocessing/utils.py +scenic/dataset_lib/big_transfer/preprocessing/vtab_ops.py +scenic/dataset_lib/coco_dataset/__init__.py +scenic/dataset_lib/coco_dataset/coco_eval.py +scenic/dataset_lib/coco_dataset/coco_utils.py +scenic/dataset_lib/coco_dataset/data/__init__.py +scenic/dataset_lib/coco_dataset/data/images/__init__.py +scenic/dataset_lib/coco_dataset/tests/__init__.py +scenic/dataset_lib/coco_dataset/tests/test_coco_utils.py +scenic/dataset_lib/tests/__init__.py +scenic/dataset_lib/tests/test_dataset_utils.py +scenic/model_lib/__init__.py +scenic/model_lib/models.py +scenic/model_lib/base_models/__init__.py +scenic/model_lib/base_models/base_model.py +scenic/model_lib/base_models/box_utils.py +scenic/model_lib/base_models/classification_model.py +scenic/model_lib/base_models/encoder_decoder_model.py +scenic/model_lib/base_models/model_utils.py +scenic/model_lib/base_models/multilabel_classification_model.py +scenic/model_lib/base_models/regression_model.py +scenic/model_lib/base_models/segmentation_model.py +scenic/model_lib/base_models/tests/__init__.py +scenic/model_lib/base_models/tests/test_box_utils.py +scenic/model_lib/base_models/tests/test_classification_model.py +scenic/model_lib/base_models/tests/test_encoder_decoder_model.py +scenic/model_lib/base_models/tests/test_model_utils.py +scenic/model_lib/base_models/tests/test_multilabel_classification_model.py +scenic/model_lib/base_models/tests/test_regression_model.py +scenic/model_lib/base_models/tests/test_segmentation_model.py +scenic/model_lib/layers/__init__.py +scenic/model_lib/layers/attention_layers.py +scenic/model_lib/layers/masked_layers.py +scenic/model_lib/layers/nn_layers.py +scenic/model_lib/layers/nn_ops.py +scenic/model_lib/layers/tests/__init__.py +scenic/model_lib/layers/tests/test_attention_layers.py +scenic/model_lib/layers/tests/test_masked_layers.py +scenic/model_lib/layers/tests/test_nn_layers.py +scenic/model_lib/layers/tests/test_nn_ops.py +scenic/model_lib/matchers/__init__.py +scenic/model_lib/matchers/common.py +scenic/model_lib/matchers/greedy.py +scenic/model_lib/matchers/hungarian.py +scenic/model_lib/matchers/hungarian_cover.py +scenic/model_lib/matchers/hungarian_jax.py +scenic/model_lib/matchers/lazy.py +scenic/model_lib/matchers/sinkhorn.py +scenic/model_lib/matchers/tests/__init__.py +scenic/model_lib/matchers/tests/test_matchers.py +scenic/model_lib/tests/__init__.py +scenic/model_lib/tests/test_models.py +scenic/projects/__init__.py +scenic/projects/adatape/__init__.py +scenic/projects/adatape/layers.py +scenic/projects/adatape/main.py +scenic/projects/adatape/adatape_vit/__init__.py +scenic/projects/adatape/adatape_vit/adatape_classify_trainer.py +scenic/projects/adatape/adatape_vit/adatape_trainer.py +scenic/projects/adatape/adatape_vit/adatape_vit.py +scenic/projects/adatape/dataset/__init__.py +scenic/projects/adatape/dataset/parity_dataset.py +scenic/projects/baselines/__init__.py +scenic/projects/baselines/axial_resnet.py +scenic/projects/baselines/bit_resnet.py +scenic/projects/baselines/fully_connected.py +scenic/projects/baselines/hybrid_vit.py +scenic/projects/baselines/mixer.py +scenic/projects/baselines/resnet.py +scenic/projects/baselines/simple_cnn.py +scenic/projects/baselines/unet.py +scenic/projects/baselines/vit.py +scenic/projects/baselines/bert/__init__.py +scenic/projects/baselines/bert/bert_base_model.py +scenic/projects/baselines/bert/layers.py +scenic/projects/baselines/bert/main.py +scenic/projects/baselines/bert/model.py +scenic/projects/baselines/bert/train_utils.py +scenic/projects/baselines/bert/trainer.py +scenic/projects/baselines/bert/configs/__init__.py +scenic/projects/baselines/bert/configs/bert_pretraining_config.py +scenic/projects/baselines/bert/configs/glue/__init__.py +scenic/projects/baselines/bert/configs/glue/bert_glue_config.py +scenic/projects/baselines/bert/configs/glue/glue_common.py +scenic/projects/baselines/bert/configs/glue/glue_fewshot.py +scenic/projects/baselines/bert/configs/glue/tasks/__init__.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_cola_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_mnli_matched_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_mnli_mismatched_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_mrpc_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_qnli_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_qqp_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_rte_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_sst2_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_stsb_config.py +scenic/projects/baselines/bert/configs/glue/tasks/bert_wnli_config.py +scenic/projects/baselines/bert/datasets/__init__.py +scenic/projects/baselines/bert/datasets/bert_glue_dataset.py +scenic/projects/baselines/bert/datasets/bert_wikibooks_dataset.py +scenic/projects/baselines/centernet/__init__.py +scenic/projects/baselines/centernet/evaluate.py +scenic/projects/baselines/centernet/evaluators.py +scenic/projects/baselines/centernet/input_pipeline.py +scenic/projects/baselines/centernet/main.py +scenic/projects/baselines/centernet/optimizer_utils.py +scenic/projects/baselines/centernet/train_utils.py +scenic/projects/baselines/centernet/trainer.py +scenic/projects/baselines/centernet/transforms.py +scenic/projects/baselines/centernet/configs/__init__.py +scenic/projects/baselines/centernet/configs/centernet2_CXT_LSJ_4x.py +scenic/projects/baselines/centernet/configs/centernet2_O365_ViTDetH_LSJ_75e.py +scenic/projects/baselines/centernet/configs/centernet2_ViTDetB_LSJ_4x.py +scenic/projects/baselines/centernet/configs/centernet2_ViTDetH_LSJ_25e_ftO365.py +scenic/projects/baselines/centernet/configs/centernet2_ViTDetH_LSJ_75e.py +scenic/projects/baselines/centernet/configs/centernet2_ViTDetL_LSJ_100e.py +scenic/projects/baselines/centernet/configs/centernet_ViTDetB_LSJ_4x.py +scenic/projects/baselines/centernet/configs/centernet_ViTDetB_S4_LSJ_4x.py +scenic/projects/baselines/centernet/modeling/__init__.py +scenic/projects/baselines/centernet/modeling/box_head.py +scenic/projects/baselines/centernet/modeling/centernet.py +scenic/projects/baselines/centernet/modeling/centernet2.py +scenic/projects/baselines/centernet/modeling/centernet_head.py +scenic/projects/baselines/centernet/modeling/centernet_utils.py +scenic/projects/baselines/centernet/modeling/convnext.py +scenic/projects/baselines/centernet/modeling/fpn.py +scenic/projects/baselines/centernet/modeling/iou_assignment.py +scenic/projects/baselines/centernet/modeling/nms.py +scenic/projects/baselines/centernet/modeling/roi_align.py +scenic/projects/baselines/centernet/modeling/roi_head_utils.py +scenic/projects/baselines/centernet/modeling/roi_heads.py +scenic/projects/baselines/centernet/modeling/vitdet.py +scenic/projects/baselines/clip/__init__.py +scenic/projects/baselines/clip/download.py +scenic/projects/baselines/clip/layers.py +scenic/projects/baselines/clip/model.py +scenic/projects/baselines/clip/tokenizer.py +scenic/projects/baselines/configs/__init__.py +scenic/projects/baselines/configs/cityscapes/__init__.py +scenic/projects/baselines/configs/cityscapes/cityscapes_config.py +scenic/projects/baselines/configs/imagenet/__init__.py +scenic/projects/baselines/configs/imagenet/imagenet_augreg_mixer_config.py +scenic/projects/baselines/configs/imagenet/imagenet_augreg_vit_config.py +scenic/projects/baselines/configs/imagenet/imagenet_axial_resnet_config.py +scenic/projects/baselines/configs/imagenet/imagenet_bit_resnet_config.py +scenic/projects/baselines/configs/imagenet/imagenet_resnet_config.py +scenic/projects/baselines/configs/imagenet/imagenet_resnet_randaug_config.py +scenic/projects/baselines/configs/imagenet/imagenet_vit_config.py +scenic/projects/baselines/configs/imagenet/optax_imagenet_augreg_vit_config.py +scenic/projects/baselines/configs/mnist/__init__.py +scenic/projects/baselines/configs/mnist/mnist_config.py +scenic/projects/baselines/detr/__init__.py +scenic/projects/baselines/detr/detr_base_model.py +scenic/projects/baselines/detr/input_pipeline_detection.py +scenic/projects/baselines/detr/main.py +scenic/projects/baselines/detr/model.py +scenic/projects/baselines/detr/train_utils.py +scenic/projects/baselines/detr/trainer.py +scenic/projects/baselines/detr/transforms.py +scenic/projects/baselines/detr/configs/__init__.py +scenic/projects/baselines/detr/configs/detr_config.py +scenic/projects/baselines/detr/configs/detr_sinkhorn_config.py +scenic/projects/baselines/detr/tests/__init__.py +scenic/projects/baselines/detr/tests/test_datasets.py +scenic/projects/baselines/detr/tests/test_detr_base_model.py +scenic/projects/baselines/detr/tests/test_model.py +scenic/projects/baselines/detr/tests/test_train_utils.py +scenic/projects/baselines/detr/tests/test_transforms.py +scenic/projects/baselines/detr/tests/test_util.py +scenic/projects/baselines/segment_anything/__init__.py +scenic/projects/baselines/segment_anything/demo_utils.py +scenic/projects/baselines/segment_anything/modeling/__init__.py +scenic/projects/baselines/segment_anything/modeling/image_encoder.py +scenic/projects/baselines/segment_anything/modeling/mask_decoder.py +scenic/projects/baselines/segment_anything/modeling/nms.py +scenic/projects/baselines/segment_anything/modeling/prompt_encoder.py +scenic/projects/baselines/segment_anything/modeling/sam.py +scenic/projects/baselines/segment_anything/modeling/transformer.py +scenic/projects/baselines/segment_anything/modeling/utils.py +scenic/projects/baselines/tests/__init__.py +scenic/projects/baselines/tests/test_axial_resnet.py +scenic/projects/baselines/tests/test_mixer.py +scenic/projects/baselines/tests/test_unet.py +scenic/projects/baselines/tests/test_vit.py +scenic/projects/boundary_attention/__init__.py +scenic/projects/boundary_attention/eval_main.py +scenic/projects/boundary_attention/eval_manager.py +scenic/projects/boundary_attention/main.py +scenic/projects/boundary_attention/train_utils.py +scenic/projects/boundary_attention/trainer.py +scenic/projects/boundary_attention/types.py +scenic/projects/boundary_attention/configs/__init__.py +scenic/projects/boundary_attention/configs/base_config.py +scenic/projects/boundary_attention/configs/boundary_attention_model_config.py +scenic/projects/boundary_attention/configs/dataset_configs.py +scenic/projects/boundary_attention/configs/deformable_boundary_attention_model_config.py +scenic/projects/boundary_attention/configs/kaleidoshapes_config.py +scenic/projects/boundary_attention/configs/model_configs.py +scenic/projects/boundary_attention/configs/training_config.py +scenic/projects/boundary_attention/dataset_lib/__init__.py +scenic/projects/boundary_attention/dataset_lib/dataloader.py +scenic/projects/boundary_attention/dataset_lib/datasets/__init__.py +scenic/projects/boundary_attention/dataset_lib/datasets/dataset_utils.py +scenic/projects/boundary_attention/dataset_lib/datasets/kaleidoshapes_dataset.py +scenic/projects/boundary_attention/dataset_lib/datasets/kaleidoshapes_dataset_utils.py +scenic/projects/boundary_attention/dataset_lib/datasets/pickled_dataset.py +scenic/projects/boundary_attention/field_of_junctions_jax/__init__.py +scenic/projects/boundary_attention/field_of_junctions_jax/field_of_junctions.py +scenic/projects/boundary_attention/field_of_junctions_jax/foj_helpers.py +scenic/projects/boundary_attention/helpers/__init__.py +scenic/projects/boundary_attention/helpers/additive_noise_model.py +scenic/projects/boundary_attention/helpers/get_input_opts.py +scenic/projects/boundary_attention/helpers/junction_functions.py +scenic/projects/boundary_attention/helpers/params2maps.py +scenic/projects/boundary_attention/helpers/perlin_noise.py +scenic/projects/boundary_attention/helpers/render_junctions.py +scenic/projects/boundary_attention/helpers/test_new_images.py +scenic/projects/boundary_attention/helpers/train_utils.py +scenic/projects/boundary_attention/helpers/viz_utils.py +scenic/projects/boundary_attention/kaleidoshapes/__init__.py +scenic/projects/boundary_attention/kaleidoshapes/kaleidoshapes.py +scenic/projects/boundary_attention/kaleidoshapes/make_kaleido_image.py +scenic/projects/boundary_attention/kaleidoshapes/plot_image.py +scenic/projects/boundary_attention/loss_lib/__init__.py +scenic/projects/boundary_attention/loss_lib/boundary_attention_loss.py +scenic/projects/boundary_attention/loss_lib/metrics.py +scenic/projects/boundary_attention/loss_lib/metrics_dict.py +scenic/projects/boundary_attention/models/__init__.py +scenic/projects/boundary_attention/models/all_models.py +scenic/projects/boundary_attention/models/boundary_attention.py +scenic/projects/boundary_attention/models/model_lib/__init__.py +scenic/projects/boundary_attention/models/model_lib/attention_blocks.py +scenic/projects/boundary_attention/models/model_lib/boundary_attention_model_base.py +scenic/projects/boundary_attention/models/model_lib/deformable_attention_blocks.py +scenic/projects/boundary_attention/models/model_lib/deformable_attention_utils.py +scenic/projects/boundary_attention/models/model_lib/deformable_refinement_blocks.py +scenic/projects/boundary_attention/models/model_lib/initialization_blocks.py +scenic/projects/boundary_attention/models/model_lib/misc_blocks.py +scenic/projects/boundary_attention/models/model_lib/model_utils.py +scenic/projects/boundary_attention/models/model_lib/patch_mixer_blocks.py +scenic/projects/boundary_attention/models/model_lib/refinement_blocks.py +scenic/projects/boundary_attention/models/model_lib/rope_embedding.py +scenic/projects/densevoc/__init__.py +scenic/projects/densevoc/chota.py +scenic/projects/densevoc/densevoc_evaluator.py +scenic/projects/densevoc/evaluate.py +scenic/projects/densevoc/evaluation_utils.py +scenic/projects/densevoc/input_pipeline.py +scenic/projects/densevoc/input_utils.py +scenic/projects/densevoc/main.py +scenic/projects/densevoc/trainer.py +scenic/projects/densevoc/transforms.py +scenic/projects/densevoc/vidstg_evaluator.py +scenic/projects/densevoc/configs/__init__.py +scenic/projects/densevoc/configs/common.py +scenic/projects/densevoc/configs/densevoc_disjoint_pretraining.py +scenic/projects/densevoc/configs/densevoc_vidstg.py +scenic/projects/densevoc/configs/densevoc_vidstg_ftgrit_hard_aggregation.py +scenic/projects/densevoc/configs/densevoc_vidstg_ftgrit_soft_aggregation.py +scenic/projects/densevoc/configs/densevoc_vidstg_videoeval.py +scenic/projects/densevoc/configs/densevoc_vln.py +scenic/projects/densevoc/configs/grit_vg_384.py +scenic/projects/densevoc/modeling/__init__.py +scenic/projects/densevoc/modeling/auto_regressive_decode.py +scenic/projects/densevoc/modeling/densevoc_model.py +scenic/projects/densevoc/modeling/grit.py +scenic/projects/densevoc/modeling/text_decoder.py +scenic/projects/densevoc/modeling/tracking_layers.py +scenic/projects/densevoc/modeling/tracking_utils.py +scenic/projects/fast_vit/__init__.py +scenic/projects/fast_vit/main.py +scenic/projects/fast_vit/model_utils.py +scenic/projects/fast_vit/xvit.py +scenic/projects/fast_vit/tests/__init__.py +scenic/projects/fast_vit/tests/test_model_utils.py +scenic/projects/gerald/__init__.py +scenic/projects/gerald/ger_eval.py +scenic/projects/gerald/ger_trainer.py +scenic/projects/gerald/input_pipeline.py +scenic/projects/gerald/main.py +scenic/projects/gerald/prepare_ald_codes.py +scenic/projects/gerald/utils.py +scenic/projects/gerald/configs/__init__.py +scenic/projects/gerald/configs/gerald_finetuning_config.py +scenic/projects/gerald/configs/gerald_pretraining_config.py +scenic/projects/gerald/models/__init__.py +scenic/projects/gerald/models/ger_model.py +scenic/projects/gerald/models/git_vit.py +scenic/projects/gerald/models/text_decoder.py +scenic/projects/knowledge_visual_language/__init__.py +scenic/projects/knowledge_visual_language/main.py +scenic/projects/knowledge_visual_language/trainer.py +scenic/projects/knowledge_visual_language/trainer_memory.py +scenic/projects/knowledge_visual_language/trainer_utils.py +scenic/projects/knowledge_visual_language/configs/__init__.py +scenic/projects/knowledge_visual_language/configs/finetune_okvqa_base.py +scenic/projects/knowledge_visual_language/configs/wit_memory_G.py +scenic/projects/knowledge_visual_language/configs/wit_memory_base.py +scenic/projects/knowledge_visual_language/configs/wit_memory_g.py +scenic/projects/knowledge_visual_language/configs/wit_memory_large.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_G_froze_config.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_base_config.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_base_froze_config.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_g_froze_config.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_large_config.py +scenic/projects/knowledge_visual_language/configs/wit_retrieval_soft_large_froze_config.py +scenic/projects/knowledge_visual_language/data/__init__.py +scenic/projects/knowledge_visual_language/data/cc12m_generation_dataset.py +scenic/projects/knowledge_visual_language/data/cc12m_table_dataset.py +scenic/projects/knowledge_visual_language/data/data_utils.py +scenic/projects/knowledge_visual_language/data/okvqa_dataset.py +scenic/projects/knowledge_visual_language/data/vqa_dataset.py +scenic/projects/knowledge_visual_language/data/vqa_table_dataset.py +scenic/projects/knowledge_visual_language/data/web_image_text_generation_dataset.py +scenic/projects/knowledge_visual_language/data/wiki_image_text_generation_dataset.py +scenic/projects/knowledge_visual_language/data/wit_table_dataset.py +scenic/projects/lang4video/__init__.py +scenic/projects/lang4video/configs/__init__.py +scenic/projects/lang4video/configs/datasets/__init__.py +scenic/projects/lang4video/configs/train/__init__.py +scenic/projects/lang4video/configs/zero_shot/__init__.py +scenic/projects/lang4video/model/__init__.py +scenic/projects/lang4video/trainer/__init__.py +scenic/projects/layout_denoise/__init__.py +scenic/projects/layout_denoise/base_model.py +scenic/projects/layout_denoise/main.py +scenic/projects/layout_denoise/model.py +scenic/projects/layout_denoise/train_utils.py +scenic/projects/layout_denoise/trainer.py +scenic/projects/layout_denoise/configs/__init__.py +scenic/projects/layout_denoise/configs/dataset_config.py +scenic/projects/layout_denoise/configs/detr.py +scenic/projects/layout_denoise/datasets/__init__.py +scenic/projects/layout_denoise/datasets/dataset.py +scenic/projects/layout_denoise/datasets/parsers.py +scenic/projects/layout_denoise/layers/__init__.py +scenic/projects/layout_denoise/layers/common.py +scenic/projects/layout_denoise/layers/embedding.py +scenic/projects/layout_denoise/layers/predictor.py +scenic/projects/layout_denoise/layers/transformer.py +scenic/projects/loca/__init__.py +scenic/projects/loca/loca_dataset.py +scenic/projects/loca/main.py +scenic/projects/loca/ops.py +scenic/projects/loca/trainer.py +scenic/projects/loca/utils.py +scenic/projects/loca/vit.py +scenic/projects/loca/configs/__init__.py +scenic/projects/loca/configs/loca_imnet1k_base16.py +scenic/projects/matvit/__init__.py +scenic/projects/matvit/classification_eval_main.py +scenic/projects/matvit/layers.py +scenic/projects/matvit/main.py +scenic/projects/matvit/matvit.py +scenic/projects/matvit/trainer.py +scenic/projects/mbt/__init__.py +scenic/projects/mbt/main.py +scenic/projects/mbt/model.py +scenic/projects/mbt/model_utils.py +scenic/projects/mbt/train_utils.py +scenic/projects/mbt/trainer.py +scenic/projects/mbt/configs/__init__.py +scenic/projects/mbt/configs/audioset/__init__.py +scenic/projects/mbt/configs/audioset/balanced_audioset_base.py +scenic/projects/mbt/datasets/__init__.py +scenic/projects/mbt/datasets/audiovisual_tfrecord_dataset.py +scenic/projects/mbt/datasets/dataset_utils.py +scenic/projects/mtv/__init__.py +scenic/projects/mtv/config_utils.py +scenic/projects/mtv/config_utils_test.py +scenic/projects/mtv/main.py +scenic/projects/mtv/model.py +scenic/projects/mtv/model_test.py +scenic/projects/mtv/model_utils.py +scenic/projects/mtv/model_utils_test.py +scenic/projects/mtv/train_utils.py +scenic/projects/mtv/trainer.py +scenic/projects/mtv/configs/__init__.py +scenic/projects/mtv/configs/epic_kitchens/__init__.py +scenic/projects/mtv/configs/epic_kitchens/epic_mtv_b2_cva.py +scenic/projects/mtv/configs/kinetics/__init__.py +scenic/projects/mtv/configs/kinetics/k400_mtv_b2_cva.py +scenic/projects/mtv/configs/kinetics/k600_mtv_b2_cva.py +scenic/projects/mtv/configs/kinetics/k600_mtv_l2_cva.py +scenic/projects/mtv/configs/kinetics/k700_mtv_b2_cva.py +scenic/projects/mtv/configs/kinetics/k700_mtv_l2_cva.py +scenic/projects/mtv/configs/mit/__init__.py +scenic/projects/mtv/configs/mit/mit_mtv_l2_cva.py +scenic/projects/mtv/configs/ssv2/__init__.py +scenic/projects/mtv/configs/ssv2/ssv2_mtv_b2_cva.py +scenic/projects/ncr/__init__.py +scenic/projects/ncr/base_model.py +scenic/projects/ncr/classification_trainer.py +scenic/projects/ncr/loss.py +scenic/projects/ncr/main.py +scenic/projects/ncr/resnet.py +scenic/projects/ncr/utils.py +scenic/projects/ncr/configs/__init__.py +scenic/projects/ncr/configs/mini_imagenet_blue_baseline.py +scenic/projects/ncr/configs/mini_imagenet_blue_ncr_train00.py +scenic/projects/ncr/configs/mini_imagenet_blue_ncr_train20.py +scenic/projects/ncr/configs/mini_imagenet_blue_ncr_train40.py +scenic/projects/ncr/configs/mini_imagenet_blue_ncr_train80.py +scenic/projects/ncr/configs/mini_imagenet_red_baseline.py +scenic/projects/ncr/configs/mini_imagenet_red_ncr_train00.py +scenic/projects/ncr/configs/mini_imagenet_red_ncr_train20.py +scenic/projects/ncr/configs/mini_imagenet_red_ncr_train40.py +scenic/projects/ncr/configs/mini_imagenet_red_ncr_train80.py +scenic/projects/ncr/data/__init__.py +scenic/projects/objectvivit/__init__.py +scenic/projects/objectvivit/dataset_utils.py +scenic/projects/objectvivit/datasets.py +scenic/projects/objectvivit/main.py +scenic/projects/objectvivit/model.py +scenic/projects/objectvivit/model_utils.py +scenic/projects/objectvivit/object_attention.py +scenic/projects/objectvivit/optimizer_utils.py +scenic/projects/objectvivit/train_utils.py +scenic/projects/objectvivit/trainer.py +scenic/projects/objectvivit/configs/__init__.py +scenic/projects/objectvivit/configs/ssv2_B16_baseline.py +scenic/projects/objectvivit/configs/ssv2_B16_object.py +scenic/projects/objectvivit/configs/ssv2_B16_sampling.py +scenic/projects/omninet/__init__.py +scenic/projects/omninet/main.py +scenic/projects/omninet/model.py +scenic/projects/omninet/model_utils.py +scenic/projects/omninet/tests/__init__.py +scenic/projects/omninet/tests/test_model.py +scenic/projects/owl_vit/__init__.py +scenic/projects/owl_vit/evaluator.py +scenic/projects/owl_vit/layers.py +scenic/projects/owl_vit/losses.py +scenic/projects/owl_vit/main.py +scenic/projects/owl_vit/matching_base_models.py +scenic/projects/owl_vit/models.py +scenic/projects/owl_vit/trainer.py +scenic/projects/owl_vit/utils.py +scenic/projects/owl_vit/clip/__init__.py +scenic/projects/owl_vit/clip/layers.py +scenic/projects/owl_vit/clip/model.py +scenic/projects/owl_vit/clip/tokenizer.py +scenic/projects/owl_vit/configs/__init__.py +scenic/projects/owl_vit/configs/clip_b16.py +scenic/projects/owl_vit/configs/clip_b32.py +scenic/projects/owl_vit/configs/clip_b32_finetune.py +scenic/projects/owl_vit/configs/clip_l14.py +scenic/projects/owl_vit/configs/clip_l14_with_masks.py +scenic/projects/owl_vit/configs/owl_v2_clip_b16.py +scenic/projects/owl_vit/configs/owl_v2_clip_l14.py +scenic/projects/owl_vit/data/__init__.py +scenic/projects/owl_vit/notebooks/__init__.py +scenic/projects/owl_vit/notebooks/inference.py +scenic/projects/owl_vit/notebooks/interactive.py +scenic/projects/owl_vit/notebooks/numpy_cache.py +scenic/projects/owl_vit/notebooks/plotting.py +scenic/projects/owl_vit/notebooks/tests/__init__.py +scenic/projects/owl_vit/notebooks/tests/inference_test.py +scenic/projects/owl_vit/notebooks/tests/interactive_test.py +scenic/projects/owl_vit/notebooks/tests/numpy_cache_test.py +scenic/projects/owl_vit/notebooks/tests/plotting_test.py +scenic/projects/owl_vit/preprocessing/__init__.py +scenic/projects/owl_vit/preprocessing/image_ops.py +scenic/projects/owl_vit/preprocessing/input_pipeline.py +scenic/projects/owl_vit/preprocessing/label_ops.py +scenic/projects/owl_vit/preprocessing/modalities.py +scenic/projects/owl_vit/preprocessing/mosaic.py +scenic/projects/owl_vit/preprocessing/transforms.py +scenic/projects/owl_vit/tests/__init__.py +scenic/projects/owl_vit/tests/checkpoint_loading_test.py +scenic/projects/owl_vit/tests/layers_test.py +scenic/projects/owl_vit/tests/models_test.py +scenic/projects/owl_vit/tests/utils_test.py +scenic/projects/pixel_llm/__init__.py +scenic/projects/pixel_llm/auto_regressive_decode.py +scenic/projects/pixel_llm/densecap_evaluator.py +scenic/projects/pixel_llm/evaluate.py +scenic/projects/pixel_llm/evaluators.py +scenic/projects/pixel_llm/main.py +scenic/projects/pixel_llm/partition_utils.py +scenic/projects/pixel_llm/tokenizers.py +scenic/projects/pixel_llm/train_utils.py +scenic/projects/pixel_llm/trainer.py +scenic/projects/pixel_llm/configs/__init__.py +scenic/projects/pixel_llm/configs/common.py +scenic/projects/pointcloud/__init__.py +scenic/projects/pointcloud/main.py +scenic/projects/pointcloud/main_s3dis.py +scenic/projects/pointcloud/main_seg.py +scenic/projects/pointcloud/models.py +scenic/projects/pointcloud/models_test.py +scenic/projects/pointcloud/pointcloud_dataset.py +scenic/projects/pointcloud/s3dis_dataset.py +scenic/projects/pointcloud/segmentation_model.py +scenic/projects/pointcloud/segmentation_trainer.py +scenic/projects/pointcloud/shapenet_dataset.py +scenic/projects/pointcloud/configs/__init__.py +scenic/projects/pointcloud/configs/pct_config.py +scenic/projects/pointcloud/configs/pct_segmentation_s3dis.py +scenic/projects/pointcloud/configs/pct_segmentation_shapenet.py +scenic/projects/polyvit/__init__.py +scenic/projects/polyvit/layers.py +scenic/projects/polyvit/main.py +scenic/projects/polyvit/model.py +scenic/projects/polyvit/model_utils.py +scenic/projects/polyvit/polyvit_base_model.py +scenic/projects/polyvit/train_utils.py +scenic/projects/polyvit/trainer.py +scenic/projects/polyvit/configs/__init__.py +scenic/projects/polyvit/configs/polyvit_all.py +scenic/projects/polyvit/tests/__init__.py +scenic/projects/polyvit/tests/test_layers.py +scenic/projects/robust_segvit/__init__.py +scenic/projects/robust_segvit/datasets/__init__.py +scenic/projects/robust_segvit/datasets/cityscapes_variants.py +scenic/projects/robust_segvit/datasets/datasets_info.py +scenic/projects/robust_segvit/datasets/denoise_utils.py +scenic/projects/robust_segvit/datasets/segmentation_datasets.py +scenic/projects/robust_segvit/datasets/segmentation_variants.py +scenic/projects/robust_segvit/tests/__init__.py +scenic/projects/robust_segvit/tests/segmentation_datasets_test.py +scenic/projects/robust_segvit/tests/segmentation_variants_test.py +scenic/projects/streaming_dvc/__init__.py +scenic/projects/streaming_dvc/caption_evaluator.py +scenic/projects/streaming_dvc/cococap_eval.py +scenic/projects/streaming_dvc/densecap_evaluator.py +scenic/projects/streaming_dvc/evaluate.py +scenic/projects/streaming_dvc/main.py +scenic/projects/streaming_dvc/optimizer_utils.py +scenic/projects/streaming_dvc/partition_utils.py +scenic/projects/streaming_dvc/post_processing_utils.py +scenic/projects/streaming_dvc/train_utils.py +scenic/projects/streaming_dvc/trainer.py +scenic/projects/streaming_dvc/configs/__init__.py +scenic/projects/streaming_dvc/configs/common.py +scenic/projects/streaming_dvc/configs/git_anet_paragraph_streaming_input.py +scenic/projects/streaming_dvc/configs/git_anet_streaming_input_output.py +scenic/projects/streaming_dvc/configs/git_vitt_streaming_input_output.py +scenic/projects/streaming_dvc/configs/git_youcook2_paragraph_streaming_input.py +scenic/projects/streaming_dvc/configs/git_youcook2_streaming_input_output.py +scenic/projects/streaming_dvc/configs/vid2seq_anet_streaming_input_output.py +scenic/projects/streaming_dvc/configs/vid2seq_vitt_streaming_input_output.py +scenic/projects/streaming_dvc/configs/vid2seq_youcook2_streaming_input_output.py +scenic/projects/streaming_dvc/io/__init__.py +scenic/projects/streaming_dvc/io/densecap_ops.py +scenic/projects/streaming_dvc/io/flexio.py +scenic/projects/streaming_dvc/io/ops.py +scenic/projects/streaming_dvc/modeling/__init__.py +scenic/projects/streaming_dvc/modeling/auto_regressive_decode.py +scenic/projects/streaming_dvc/modeling/model.py +scenic/projects/streaming_dvc/modeling/streaming_model.py +scenic/projects/streaming_dvc/modeling/streaming_utils.py +scenic/projects/streaming_dvc/modeling/text_decoder.py +scenic/projects/streaming_dvc/modeling/vid2seq_model.py +scenic/projects/streaming_dvc/modeling/vit.py +scenic/projects/svvit/__init__.py +scenic/projects/svvit/classification_trainer.py +scenic/projects/svvit/inference.py +scenic/projects/svvit/main.py +scenic/projects/svvit/metrics.py +scenic/projects/svvit/transfer_trainer.py +scenic/projects/svvit/vit.py +scenic/projects/svvit/xvit.py +scenic/projects/svvit/configs/__init__.py +scenic/projects/svvit/configs/pileup_coverage_vit_config.py +scenic/projects/svvit/configs/pileup_coverage_xvit_config.py +scenic/projects/svvit/configs/vit_config.py +scenic/projects/svvit/configs/vit_finetuning_config.py +scenic/projects/svvit/configs/xvit_config.py +scenic/projects/svvit/configs/xvit_config_eval.py +scenic/projects/svvit/configs/xvit_finetuning_config.py +scenic/projects/svvit/datasets/__init__.py +scenic/projects/svvit/datasets/pileup_coverage_dataset.py +scenic/projects/svvit/datasets/pileup_window_dataset.py +scenic/projects/svvit/tests/__init__.py +scenic/projects/svvit/tests/metrics_test.py +scenic/projects/t5/__init__.py +scenic/projects/t5/inspect_model.py +scenic/projects/t5/layers.py +scenic/projects/t5/model.py +scenic/projects/t5/tokenizer.py +scenic/projects/token_learner/__init__.py +scenic/projects/token_learner/main.py +scenic/projects/token_learner/model.py +scenic/projects/token_learner/configs/__init__.py +scenic/projects/token_learner/configs/im1k_token_learner_config.py +scenic/projects/token_learner/data/__init__.py +scenic/projects/token_learner/tests/__init__.py +scenic/projects/token_learner/tests/test_model.py +scenic/projects/verbs_in_action/__init__.py +scenic/projects/verbs_in_action/clip4clip_model.py +scenic/projects/verbs_in_action/losses.py +scenic/projects/verbs_in_action/main.py +scenic/projects/verbs_in_action/tfrecord_dataset.py +scenic/projects/verbs_in_action/trainer.py +scenic/projects/verbs_in_action/utils.py +scenic/projects/verbs_in_action/configs/__init__.py +scenic/projects/verbs_in_action/configs/baseline.py +scenic/projects/verbs_in_action/configs/vfc.py +scenic/projects/vid2seq/__init__.py +scenic/projects/vid2seq/data_utils.py +scenic/projects/vid2seq/dvc_eval.py +scenic/projects/vid2seq/generate_from_file.py +scenic/projects/vid2seq/load_utils.py +scenic/projects/vid2seq/main.py +scenic/projects/vid2seq/models.py +scenic/projects/vid2seq/train_utils.py +scenic/projects/vid2seq/trainer.py +scenic/projects/vid2seq/configs/__init__.py +scenic/projects/vid2seq/configs/activitynet-captions.py +scenic/projects/vid2seq/configs/youcook2.py +scenic/projects/vid2seq/configs/yttemporal.py +scenic/projects/vid2seq/datasets/__init__.py +scenic/projects/vid2seq/datasets/dense_video_captioning_tfrecord_dataset.py +scenic/projects/vivit/__init__.py +scenic/projects/vivit/evaluation_lib.py +scenic/projects/vivit/main.py +scenic/projects/vivit/model.py +scenic/projects/vivit/model_utils.py +scenic/projects/vivit/train_utils.py +scenic/projects/vivit/trainer.py +scenic/projects/vivit/configs/__init__.py +scenic/projects/vivit/configs/epic_kitchens/__init__.py +scenic/projects/vivit/configs/epic_kitchens/vivit_large_factorised_encoder.py +scenic/projects/vivit/configs/kinetics400/__init__.py +scenic/projects/vivit/configs/kinetics400/vivit_base_factorised_encoder.py +scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py +scenic/projects/vivit/configs/kinetics400/vivit_large_factorised_encoder.py +scenic/projects/vivit/configs/kinetics600/__init__.py +scenic/projects/vivit/configs/kinetics600/vivit_large_factorised_encoder.py +scenic/projects/vivit/configs/something_something_v2/__init__.py +scenic/projects/vivit/configs/something_something_v2/vivit_large_factorised_encoder.py +scenic/projects/vivit/data/__init__.py +scenic/projects/vivit/data/file_utils.py +scenic/projects/vivit/data/video_tfrecord_dataset.py +scenic/projects/vivit/data/tests/__init__.py +scenic/projects/vivit/data/tests/test_video_tfrecord_dataset.py +scenic/projects/vivit/tests/__init__.py +scenic/projects/vivit/tests/test_vivit_metrics.py +scenic/projects/vivit/tests/test_vivit_trainer.py +scenic/train_lib/__init__.py +scenic/train_lib/classification_trainer.py +scenic/train_lib/lr_schedules.py +scenic/train_lib/optax.py +scenic/train_lib/optimizers.py +scenic/train_lib/pretrain_utils.py +scenic/train_lib/train_utils.py +scenic/train_lib/trainers.py +scenic/train_lib/tests/__init__.py +scenic/train_lib/tests/test_classification_trainer.py +scenic/train_lib/tests/test_lr_schedules.py +scenic/train_lib/tests/test_optax.py +scenic/train_lib/tests/test_optimizers.py +scenic/train_lib/transfer/__init__.py +scenic/train_lib/transfer/fewshot_utils.py +scenic/train_lib/transfer/linear_probe_utils.py +scenic/train_lib/transfer/transfer_trainer.py +scenic/train_lib/transfer/tests/__init__.py +scenic/train_lib/transfer/tests/test_fewshot_utils.py \ No newline at end of file diff --git a/scenic.egg-info/dependency_links.txt b/scenic.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/scenic.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/scenic.egg-info/requires.txt b/scenic.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..c704e5f407c9042b151e6dc5a5dcb5b5c58041da --- /dev/null +++ b/scenic.egg-info/requires.txt @@ -0,0 +1,20 @@ +absl-py>=1.0.0 +numpy>=1.12 +jax>=0.4.3 +jaxlib>=0.4.3 +flax>=0.4.0 +ml-collections>=0.1.1 +tensorflow>=2.7 +immutabledict>=2.2.1 +clu>=0.0.6 +tensorflow-datasets +optax@ git+https://github.com/google-deepmind/optax.git@main + +[testing] +pytest +shapely +ott-jax>=0.2.0 +sklearn +lingvo==0.12.6 +seaborn>=0.11.2 +dmvr@ git+https://github.com/google-deepmind/dmvr.git diff --git a/scenic.egg-info/top_level.txt b/scenic.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..32e23e34877ac602f342a04b2f0ff52be2bede12 --- /dev/null +++ b/scenic.egg-info/top_level.txt @@ -0,0 +1 @@ +scenic diff --git a/scenic/__init__.py b/scenic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/__pycache__/__init__.cpython-310.pyc b/scenic/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2cc1dcd43e1ce70ee6a85dd48b33698638c3542 Binary files /dev/null and b/scenic/__pycache__/__init__.cpython-310.pyc differ diff --git a/scenic/app.py b/scenic/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccd83dbfae3ad7b0af718f7220c0ba989e10ddd --- /dev/null +++ b/scenic/app.py @@ -0,0 +1,109 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic entry point for Python application in Scenic. + +This provides run() which performs some initialization and then calls the +provided main with a JAX PRNGKey, the ConfigDict, the working directory +and a CLU MetricWriter. +We expect each scenic project to have its own main.py. It's very short but +makes it easier to maintain scenic as the number of projects grows. + +Usage in your main.py: + from scenic import app + + def main(rng: jnp.ndarray, + config: ml_collections.ConfigDict, + workdir: str, + writer: metric_writers.MetricWriter): + # Call the library that trains your model. + + if __name__ == '__main__': + app.run(main) +""" +import functools +import os + +from absl import app +from absl import flags +from absl import logging + +from clu import metric_writers +from clu import platform +import flax +import flax.linen as nn +import jax +from ml_collections import config_flags +import tensorflow as tf + +FLAGS = flags.FLAGS + +# These are general flags that are used across most of scenic projects. These +# flags can be accessed via `flags.FLAGS.` and projects can also +# define their own flags in their `main.py`. +config_flags.DEFINE_config_file( + 'config', None, 'Training configuration.', lock_config=False) +flags.DEFINE_string('workdir', None, 'Work unit directory.') +flags.DEFINE_string('dataset_service_address', None, + 'Address of the tf.data service') +flags.mark_flags_as_required(['config', 'workdir']) + +flax.config.update('flax_use_orbax_checkpointing', False) + + +def run(main): + # Provide access to --jax_backend_target and --jax_xla_backend flags. + jax.config.config_with_absl() + app.run(functools.partial(_run_main, main=main)) + + +def _run_main(argv, *, main): + """Runs the `main` method after some initial setup.""" + del argv + # Hide any GPUs form TensorFlow. Otherwise, TF might reserve memory and make + # it unavailable to JAX. + tf.config.experimental.set_visible_devices([], 'GPU') + + config = FLAGS.config + workdir = FLAGS.workdir + if 'workdir_suffix' in config: + workdir = os.path.join(workdir, config.workdir_suffix) + + # Enable wrapping of all module calls in a named_call for easier profiling: + nn.enable_named_call() + + if FLAGS.jax_backend_target: + logging.info('Using JAX backend target %s', FLAGS.jax_backend_target) + jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else + FLAGS.jax_xla_backend) + logging.info('Using JAX XLA backend %s', jax_xla_backend) + + logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count()) + logging.info('JAX devices: %r', jax.devices()) + + # Add a note so that we can tell which task is which JAX host. + # (task 0 is not guaranteed to be the host 0) + platform.work_unit().set_task_status( + f'host_id: {jax.process_index()}, host_count: {jax.process_count()}') + if jax.process_index() == 0: + platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, + workdir, 'Workdir') + + rng = jax.random.PRNGKey(config.rng_seed) + logging.info('RNG: %s', rng) + + writer = metric_writers.create_default_writer( + workdir, just_logging=jax.process_index() > 0, asynchronous=True) + + main(rng=rng, config=config, workdir=workdir, writer=writer) diff --git a/scenic/common_lib/__init__.py b/scenic/common_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/__pycache__/__init__.cpython-310.pyc b/scenic/common_lib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/__pycache__/debug_utils.cpython-310.pyc b/scenic/common_lib/__pycache__/debug_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/colabs/__init__.py b/scenic/common_lib/colabs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/colabs/scenic_playground.ipynb b/scenic/common_lib/colabs/scenic_playground.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/common_utils.py b/scenic/common_lib/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc002dc1ec7667f636a871840ede66ae90c091e --- /dev/null +++ b/scenic/common_lib/common_utils.py @@ -0,0 +1,93 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions that do not fit into other common_lib modules.""" + +import importlib +import types + +from absl import logging +import jax +import ml_collections + + +def recursive_reload(module: types.ModuleType, package_restrict: str): + """Recursively reload a module and the modules it imports. + + Args: + module: The module to reload. + + package_restrict: Only modules with this prefix will be reloaded. For + example, if package_restrict is "scenic.projects", only modules under + scenic.projects will be reloaded. package_restrict must always be set to + avoid reloading of built-in or unrelated packages that should not be + reloaded (e.g. Numpy). + + Returns: + The reloaded module object. + + Raises: + ValueError if package_restrict is empyt. + """ + reloaded = set() + if not package_restrict: + raise ValueError('package_restrict must be non-empty.') + + def reload(m): + if m in reloaded: + return m + reloaded.add(m) + for attribute_name in dir(m): + attribute = getattr(m, attribute_name) + if (isinstance(attribute, types.ModuleType) and + attribute.__name__.startswith(package_restrict)): + reload(attribute) + logging.info('Reloading %s', m.__name__) + return importlib.reload(m) + + return reload(module) + + +def to_config_dict_heuristic( + config: ml_collections.ConfigDict) -> ml_collections.ConfigDict: + """Heuristically converts dicts inside a ConfigDict into ConfigDicts. + + This function detects lists and tuples with dicts and converts those dicts + into ConfigDicts. This will address most failure cases, but the function will + not be able resolve nested cases (e.g. list(dict(list(....)))). + + Arguments: + config: Config to attempt fixing. + + Returns: + Probably fixed config. + """ + def maybe_config_dict(x): + if isinstance(x, dict): + return ml_collections.ConfigDict(x) + return x + + def maybe_config_dict_in_list(x): + if isinstance(x, (list, tuple)): + return jax.tree_util.tree_map( + maybe_config_dict, x, is_leaf=lambda y: isinstance(y, dict) + ) + return x + + config = jax.tree_util.tree_map( + maybe_config_dict_in_list, + config.to_dict(), + is_leaf=lambda x: isinstance(x, list), + ) + return ml_collections.ConfigDict(config) diff --git a/scenic/common_lib/debug_utils.py b/scenic/common_lib/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef66f2dbba03166fe324c0bcbae9083e0cd4612 --- /dev/null +++ b/scenic/common_lib/debug_utils.py @@ -0,0 +1,328 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for logging, debugging, profiling, testing, and visualization.""" + +from collections import abc +from concurrent import futures +import json +import operator +import threading +from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union + +from absl import logging +from clu import parameter_overview +import jax +import jax.numpy as jnp +from jax.tree_util import tree_map +import ml_collections + +PyTree = Any + + +def enable_jax_debugging_flags(): + """Enables some of the global JAX flags for debugging.""" + + # Enable the NaN-checker behavior to cause JAX to hard-break on the first + # occurrence of a NaN. + jax.config.update('jax_debug_nans', True) + + # Enable the compilation logger to check whether or not we're accidentally + # causing a lot of re-compilation (inspect logs for excessive jitting). + jax.config.update('jax_log_compiles', True) + + # Detect numpy-style automatic rank promotion and force strict, explicit + # casts. We can use `raise` instead of warn to raise an error. + jax.config.update('jax_numpy_rank_promotion', 'warn') + + # Print global JAX flags in logs. + logging.info('Global JAX flags: %s', jax.config.values) + + +def log_param_shapes( + params: Any, + print_params_nested_dict: bool = False, + description: Optional[str] = None, + include_stats: bool = True, +) -> int: + """Prints out shape of parameters and total number of trainable parameters. + + Args: + params: PyTree of model parameters. + print_params_nested_dict: If True, it prints parameters in shape of a nested + dict. + description: Optional description to print out before logging the parameter + summary. + include_stats: Include parameter stats if True. + + Returns: + int; Total number of trainable parameters. + """ + if print_params_nested_dict: + shape_dict = tree_map(lambda x: str(x.shape), params) + # We use json.dumps for pretty printing nested dicts. + logging.info( + 'Printing model param shape:/n%s', + json.dumps(shape_dict, sort_keys=True, indent=4), + ) + parameter_overview.log_parameter_overview( + params, include_stats=include_stats, msg=description + ) + total_params = jax.tree_util.tree_reduce( + operator.add, tree_map(lambda x: x.size, params) + ) + logging.info('Total params: %d', total_params) + return total_params + + +def input_spec_to_jax_shape_dtype_struct( + spec: Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...]], + batch_size: Optional[int] = None, +) -> jax.ShapeDtypeStruct: + """Parse an input specs into a jax.ShapeDtypeStruct.""" + spec = tuple(spec) + if batch_size and len(spec) == 1: + raise ValueError('batch_size unsupported when len(spec) is 1.') + if len(spec) == 2 and isinstance(spec[0], abc.Iterable): + shape = (batch_size,) + tuple(spec[0][1:]) if batch_size else spec[0] + dtype = spec[1] + else: + shape = (batch_size,) + tuple(spec[1:]) if batch_size else spec + dtype = jnp.float32 + return jax.ShapeDtypeStruct(shape, dtype) + + +def compute_flops( + flax_model_apply_fn: Callable[[jnp.ndarray], Any], + input_spec: Sequence[ + Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...], None] + ], + fuse_multiply_add: bool, +) -> float: + """Performs static analysis of the graph to compute theoretical FLOPs. + + One can also use the XProf profiler to get the actual FLOPs at runtime + based on device counters. Theoretical FLOPs are more useful for comparing + models across different library implementations and is hardware-agnostic. + + Args: + flax_model_apply_fn: Apply function of the flax model to be analysed. + input_spec: An iterable of (shape, dtype) pairs specifying the shape and + dtype of the inputs. If unspecified the dtype is float32. + fuse_multiply_add: Bool; If true, count a multiply and add (also known as + "multiply-accumulate" or "MAC") as 1 FLOP rather than 2 (as done by the + HLO analysis). This is commonly used in literature. + + Returns: + flops: The total number of flops. + """ + dummy_input = [] + for spec in input_spec: + if spec is not None: + in_st = input_spec_to_jax_shape_dtype_struct(spec, batch_size=1) + dummy_input.append(jnp.zeros(in_st.shape, in_st.dtype)) + else: + dummy_input.append(None) + + analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).cost_analysis() + flops = analysis['flops'] + if fuse_multiply_add: + flops = flops / 2 + logging.info('GFLOPs %0.3f for input spec: %s', flops / 10**9, input_spec) + return flops + + +def compute_flops_with_pytree( + flax_model_apply_fn: Callable[[jnp.ndarray], Any], + input_spec: PyTree, + unpack_input: bool = True, + fuse_multiply_add: bool = True, +) -> float: + """Performs static analysis of the graph to compute theoretical FLOPs. + + One can also use the XProf profiler to get the actual FLOPs at runtime + based on device counters. Theoretical FLOPs are more useful for comparing + models across different library implementations and is hardware-agnostic. + + Args: + flax_model_apply_fn: Apply function of the flax model to be analysed. + input_spec: A PyTree whose leaves are (shape, dtype) pairs specifying the + shape and dtype of the inputs. If unspecified the dtype is float32. + unpack_input: Unpack the pytree when feeding it to the model. + fuse_multiply_add: Bool; If true, count a multiply and add (also known as + "multiply-accumulate" or "MAC") as 1 FLOP rather than 2 (as done by the + HLO analysis). This is commonly used in literature. + + Returns: + flops: The total number of flops. + """ + + def check_leaf_spec(spec: Sequence[PyTree]) -> bool: + return ( + len(spec) == 2 + and isinstance(spec[0], abc.Sequence) + and all(isinstance(i, int) for i in spec[0]) + and isinstance(spec[1], jnp.dtype) + ) or (all(isinstance(i, int) for i in spec[0])) + + def create_dummy_input(spec: PyTree) -> PyTree: + if isinstance(spec, dict): + return {k: create_dummy_input(v) for k, v in spec.items()} + elif isinstance(spec, abc.Sequence): + if check_leaf_spec(spec): + in_st = input_spec_to_jax_shape_dtype_struct(spec, batch_size=1) + return jnp.zeros(in_st.shape, in_st.dtype) + else: + return tuple(create_dummy_input(child) for child in spec) + elif spec is None: + return None + else: + raise NotImplementedError('Unsupported spec type.', type(spec)) + + dummy_input = create_dummy_input(input_spec) + + if isinstance(dummy_input, dict) and unpack_input: + analysis = jax.jit(flax_model_apply_fn).lower(**dummy_input).cost_analysis() + elif isinstance(dummy_input, abc.Sequence) and unpack_input: + analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).cost_analysis() + else: + analysis = jax.jit(flax_model_apply_fn).lower(dummy_input).cost_analysis() + + flops = analysis['flops'] + if fuse_multiply_add: + flops = flops / 2 + logging.info('GFLOPs %0.3f for input spec: %s', flops / 10**9, input_spec) + return flops + + +class ConfigDictWithAccessRecord(ml_collections.ConfigDict): + """A wrapper for ConfigDicts that records access of any config field. + + ConfigDictWithAccessRecord behaves like a standard ConfigDict, except that it + records access to any config field (including nested instances of + ConfigDictWithAccessRecord). This allows testing for unused config fields. + + Example usage: + + def test_config_access(self): + with mock.patch('configs.my_config.ml_collections.ConfigDict', + test_utils.ConfigDictWithAccessRecord): + config = config_module.get_config() + config.reset_access_record() # Resets previous access records. + ... # Code that uses config. + self.assertEmpty(config.get_not_accessed()) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_access_record() + + def __getitem__(self, key: str): + self._access_record.add(key) + return super().__getitem__(key) + + def reset_access_record(self): + """Resets the record of config field accesses.""" + for value in self._fields.values(): + if isinstance(value, type(self)): + value.reset_access_record() + # object.__setattr__ avoids triggering ConfigDict's __getattr__: + object.__setattr__(self, '_access_record', set()) + + def get_not_accessed(self, prefix: str = 'config') -> Set[str]: + """Returns the set of fields that were not accessed since the last reset.""" + not_accessed = set() + for key, value in self._fields.items(): + path = f'{prefix}.{key}' + if isinstance(value, type(self)): + not_accessed |= value.get_not_accessed(prefix=path) + else: + if key not in self._access_record and key != '_access_record': + not_accessed.add(path) + return not_accessed + + +class DummyExecutor(futures.Executor): + """A mock executor that operates serially. + + Useful for debugging. + + Example usage: + + # Runs concurrently, difficult to debug: + pool = futures.ThreadPoolExecutor(max_workers=max_workers) + pool.submit(my_function) + + # For debugging: + pool = DummyExecutor() + pool.submit(my_function) # Will block and run serially. + """ + + def __init__(self): + self._shutdown = False + self._shutdown_lock = threading.Lock() + + def submit(self, fn: Callable[..., Any], *args, **kwargs) -> futures.Future: # pylint: disable=g-bare-generic + with self._shutdown_lock: + if self._shutdown: + raise RuntimeError('Cannot schedule new futures after shutdown.') + + future = futures.Future() + try: + result = fn(*args, **kwargs) + except BaseException as e: # pylint: disable=broad-except + future.set_exception(e) + else: + future.set_result(result) + return future + + def shutdown(self, wait: bool = True): # pytype: disable=signature-mismatch # overriding-parameter-name-checks + with self._shutdown_lock: + self._shutdown = True + + +class StepTraceContextHelper: + """Helper class to use jax.profiler.StepTraceAnnotation. + + This will cause a "name" event to show up on the trace timeline if the + event occurs while the process is being traced by TensorBoard. In addition, + if using accelerators, the device trace timeline will also show a "name" + event. Note that "step_num" can be set as a keyword argument to pass the + global step number to the profiler. See jax.profiler.StepTraceAnnotation. + + """ + + def __init__(self, name: str, init_step_num: int): + self.name = name + self.step_num = init_step_num + self.context = None + + def __enter__(self): + self.context = jax.profiler.StepTraceAnnotation( + self.name, step_num=self.step_num + ) + self.step_num += 1 + self.context.__enter__() + return self + + def __exit__(self, exc_type, exc_value, tb): + assert self.context is not None, 'Exited context without entering.' + self.context.__exit__(exc_type, exc_value, tb) + self.context = None + + def next_step(self): + if self.context is None: + raise ValueError('Must call next_step() within a context.') + self.__exit__(None, None, None) + self.__enter__() diff --git a/scenic/common_lib/export_utils.py b/scenic/common_lib/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd6d87064c701ae7f553aed2ecbfbd2f4f22db2 --- /dev/null +++ b/scenic/common_lib/export_utils.py @@ -0,0 +1,187 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for exporting JAX models to Tensorflow SavedModels.""" + +from typing import Any, Callable, Sequence, Optional, Union + +from jax.experimental import jax2tf +import tensorflow as tf +import tree as dm_tree + +# JAX team is working on type annotation for pytree: +# https://github.com/jax-ml/jax/issues/1555 +# A PyTree is a nested dictionary where the leaves are `jnp.ndarray`. +# TODO(aarnab): Fix type annotation once ready. +PyTree = Any + + +def convert_and_save_model( + jax_fn: Callable[[PyTree, PyTree], PyTree], + params: PyTree, + model_dir: str, + *, + input_signatures: Union[ + Sequence[tf.TensorSpec], + Sequence[Sequence[tf.TensorSpec]], + Sequence[dict[str, tf.TensorSpec]], + ], + polymorphic_shapes: Optional[ + Union[str, jax2tf.PolyShape, dict[str, str]] + ] = None, + with_gradient: bool = False, + enable_xla: bool = True, + compile_model: bool = True, + saved_model_options: Optional[tf.saved_model.SaveOptions] = None, + native_serialization: Optional[str | bool] = "default", + native_serialization_platforms: Sequence[str] | None = ("cpu", "tpu")): + """Converts a JAX function and saves a SavedModel. + + We assume that the JAX model consists of a prediction function and trained + parameters, and the computation graph of the function is saved separately from + the parameters. Saving the graph separately from the parameters reduces + the size of the Tensorflow `GraphDef`, and enables finetuning of model + parameters too. + + To use this function, a JAX model must be converted to a function of two + arguments, the model parameters and the input. + For a Scenic model, this corresponds to: + ``` + params = train_state.optimizer.target + flax_model = model.flax_model + def _predict_fn(params, input_data): + return flax_model.apply({'params': params}, input_data, train=False) + ``` + + Args: + jax_fn: A JAX function taking two arguments, the parameters and the inputs. + Both arguments may be (nested) tuples/lists/dictionaries of `np.ndarray`. + It is necessary to be able to JIT-compile this function (ie run + `jax.jit` on it). + params: The parameters, to be used as first argument for `jax_fn`. These + must be (nested) tuples/lists/dictionaries of `np.ndarray`, and will be + saved as the variables of the SavedModel. + model_dir: The directory where the model should be saved. + input_signatures: The input signatures for the second argument of `jax_fn` + (the input). A signature must be a `tensorflow.TensorSpec` instance, or a + (nested) tuple/list/dictionary thereof with a structure matching the + second argument of `jax_fn`. The first input_signature will be saved as + the default serving signature. The additional signatures will be used + only to ensure that the `jax_fn` is traced and converted to TF for the + corresponding input shapes. + polymorphic_shapes: If given then it will be used as the + `polymorphic_shapes` argument to `jax2tf.convert` for the second parameter + of `jax_fn`. In this case, a single `input_signatures` is supported, and + should have `None` in the polymorphic dimensions. This is required, for + example, to have models with dynamic batch sizes. + with_gradient: Whether the SavedModel should support gradients. If `True`, + then a custom gradient is saved. If `False`, then a + `tf.raw_ops.PreventGradient` is saved to error if a gradient is attempted. + (At the moment due to a bug in SavedModel, custom gradients are not + supported.) + enable_xla: Whether the jax2tf converter is allowed to use TF XLA ops. If + `False`, the conversion tries harder to use purely TF ops and raises an + exception if it is not possible. + compile_model: Use TensorFlow jit_compiler on the SavedModel. This + is needed if the SavedModel will be used for TensorFlow serving. + saved_model_options: Options to pass to `savedmodel.save`. + native_serialization: Serialize the JAX function natively to + StableHLO with compatibility guarantees. This makes it easier to have + confidence that the code executed when calling this function from + TensorFlow is exactly the same as JAX would run natively. See + jax2tf.convert() for details. + native_serialization_platforms: When the "native_serialization" flag is + used, the platforms that it will be serialised to. Must be a tuple of + strings, including a subset of: ['cpu', 'cuda', 'rocm', 'tpu']. + 'None', specifies the JAX default backend on the machine where the + lowering is done. + + Raises: + ValueError: If at least one input signature is not defined. However, if + `polymorphic_shapes` is given, then only one input signature is supported. + """ + if not input_signatures: + raise ValueError("At least one input_signature must be given.") + if polymorphic_shapes is not None and len(input_signatures) > 1: + raise ValueError("For shape-polymorphic conversion a single " + "input_signature is supported.") + tf_fn = jax2tf.convert( + jax_fn, + with_gradient=with_gradient, + polymorphic_shapes=[None, polymorphic_shapes], + enable_xla=enable_xla, + native_serialization=native_serialization, + native_serialization_platforms=native_serialization_platforms) + + def get_tf_variable(path, param): + return tf.Variable(param, trainable=with_gradient, name="/".join(path)) + + param_vars = dm_tree.map_structure_with_path( + # Due to a bug in SavedModel it is not possible to use `tf.GradientTape` + # on a function converted with jax2tf and loaded from SavedModel. Thus, we + # mark the variables as non-trainable to ensure that users of the + # SavedModel will not try to fine tune them. + get_tf_variable, params) + tf_graph = tf.function( + lambda inputs: tf_fn(param_vars, inputs), + autograph=False, + jit_compile=compile_model) + + # This signature is needed for TensorFlow Serving use. + signatures = { + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + tf_graph.get_concrete_function(input_signatures[0]) + } + + for input_signature in input_signatures[1:]: + # If there are more signatures, trace and cache a TF function for each one. + tf_graph.get_concrete_function(input_signature) + wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars) + + if saved_model_options: + saved_model_options.function_aliases = {"inference_func": tf_graph} + else: + saved_model_options = tf.saved_model.SaveOptions( + function_aliases={"inference_func": tf_graph} + ) + + if with_gradient: + saved_model_options.experimental_custom_gradients = True + + tf.saved_model.save( + wrapper, model_dir, signatures=signatures, options=saved_model_options + ) + + +class _ReusableSavedModelWrapper(tf.train.Checkpoint): + """Wraps a function and its parameters for saving to a SavedModel. + + Implements the interface described at + https://www.tensorflow.org/hub/reusable_saved_models. + """ + + def __init__(self, tf_graph: Callable[[PyTree], PyTree], param_vars: PyTree): + """Constructor. + + Args: + tf_graph: A `tf.function` taking one argument (the inputs), which can be + be tuples/lists/dictionaries of `np.ndarray` or tensors. The function + may have references to the `tf.Variables` in `param_vars`. + param_vars: The parameters, as tuples/lists/dictionaries of + `tf.Variable`, to be saved as the variables of the SavedModel. + """ + super().__init__() + self.variables = tf.nest.flatten(param_vars) + self.trainable_variables = [v for v in self.variables if v.trainable] + self.__call__ = tf_graph diff --git a/scenic/common_lib/image_utils.py b/scenic/common_lib/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be4866d52a16a79393b5b84f2431301eb770d083 --- /dev/null +++ b/scenic/common_lib/image_utils.py @@ -0,0 +1,99 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Image-related utility functions.""" + +from typing import Optional + +import jax +import jax.numpy as jnp +import numpy as np +from PIL import Image + + +def compress_masks(mask_probs, k=3): + """At each pixel, stores the largest k probabilities and their indices.""" + if mask_probs.ndim == 5: + mask_probs = jnp.squeeze(mask_probs, axis=-1) # Remove channel dim. + # Input shape should be [b, num_queries, out_h, out_w]. + assert mask_probs.ndim == 4, f'Expected 4-D input, got {mask_probs.shape}' + mask_probs = jnp.transpose(mask_probs, [0, 2, 3, 1]) + vals, inds = jax.lax.top_k(mask_probs, k=k) + # Back to [b, k, out_h, out_w] + vals = jnp.transpose(vals, [0, 3, 1, 2]) + inds = jnp.transpose(inds, [0, 3, 1, 2]) + return vals, inds + + +def decompress_masks(compressed_masks, num_queries): + """Reconstructs the uncompressed mask representation.""" + vals, inds = compressed_masks + b, _, h, w = vals.shape + mask_probs = np.zeros((b, num_queries, h, w)) + ib, _, ih, iw = np.meshgrid( + range(b), range(1), range(h), range(w), indexing='ij') + mask_probs[ib, inds, ih, iw] = vals + return mask_probs + + +def resize_pil(image_or_batch: np.ndarray, + *, + out_h: int, + out_w: int, + num_batch_dims: Optional[int] = None, + method: str = 'linear') -> np.ndarray: + """Resizes an image or batch of images using PIL. + + This function handles images with or without channel dimension, but requires + any leading batch dimensions to be specified explicitly to avoid ambiguities. + + Args: + image_or_batch: Image or batch of images. + out_h: Image height after resizing. + out_w: Image width after resizing. + num_batch_dims: Number of leading dimensions that are to be treated as batch + dimensions, e.g. 0 for single images or 1 for simple batches. If None, the + input is assumed to be a single image. + method: String indicating the resizing method. One of "linear" or "nearest". + + Returns: + Resized image or batch of images. + """ + if num_batch_dims is None: + num_batch_dims = 0 + if image_or_batch.ndim > 3 or (image_or_batch.ndim == 3 and + image_or_batch.shape[-1] not in [3, 4]): + raise ValueError('If a batch of images is supplied, num_batch_dims must ' + 'be specified.') + + if method == 'linear': + resample = Image.Resampling.BILINEAR + elif method == 'nearest': + resample = Image.Resampling.NEAREST + elif method == 'lanczos': + resample = Image.Resampling.LANCZOS + else: + raise NotImplementedError(f'Method not implemented: {method}') + + batch_dims = image_or_batch.shape[:num_batch_dims] + image_dims = image_or_batch.shape[num_batch_dims:] + batch = np.reshape(image_or_batch, (-1,) + image_dims) + + pil_size = [out_w, out_h] + resized = np.stack([ + np.asarray(Image.fromarray(image).resize(pil_size, resample)) # pytype: disable=wrong-arg-types # pillow-102-upgrade + for image in batch + ]) + + return np.reshape(resized, batch_dims + resized.shape[1:]) diff --git a/scenic/common_lib/tests/__init__.py b/scenic/common_lib/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/tests/test_common_utils.py b/scenic/common_lib/tests/test_common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/tests/test_debug_utils.py b/scenic/common_lib/tests/test_debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/tests/test_image_utils.py b/scenic/common_lib/tests/test_image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/tests/test_video_utils.py b/scenic/common_lib/tests/test_video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/common_lib/video_utils.py b/scenic/common_lib/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f705bb70e27e6edca0e6d933a308d61d9e29558 --- /dev/null +++ b/scenic/common_lib/video_utils.py @@ -0,0 +1,36 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Video-related utility functions.""" + +import jax.numpy as jnp + + +def sample_frames_uniformly(x: jnp.ndarray, + n_sampled_frames: int) -> jnp.ndarray: + """Sample frames from the input video.""" + if x.ndim != 5: + raise ValueError('Input shape should be [bs, t, h, w, c].') + num_frames = x.shape[1] + if n_sampled_frames < num_frames: + t_start_idx = num_frames / (n_sampled_frames + 1) + t_step = t_start_idx + else: + t_start_idx = 0 + t_step = 1 + t_end_idx = num_frames + temporal_indices = jnp.arange(t_start_idx, t_end_idx, t_step) + temporal_indices = jnp.round(temporal_indices).astype(jnp.int32) + temporal_indices = jnp.minimum(temporal_indices, num_frames - 1) + return x[:, temporal_indices] # [n, t_s, in_h, in_w, c] diff --git a/scenic/dataset_lib/__init__.py b/scenic/dataset_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/__pycache__/__init__.cpython-310.pyc b/scenic/dataset_lib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/__pycache__/dataset_utils.cpython-310.pyc b/scenic/dataset_lib/__pycache__/dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/__pycache__/datasets.cpython-310.pyc b/scenic/dataset_lib/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/bair_dataset.py b/scenic/dataset_lib/bair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4fc08d7440317a83656adcc7332348dbbb079c --- /dev/null +++ b/scenic/dataset_lib/bair_dataset.py @@ -0,0 +1,233 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the BAIR Robot dataset.""" + +import functools +from typing import Optional + +from absl import logging +from dmvr import processors +from flax import jax_utils +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + + +def preprocess_train_example(example, + camera_name='image_main', + dtype=tf.float32, + zero_centering=True): + """Preprocesses the given video. + + Args: + example: dict; Example that has an 'image_main'. + camera_name: Name of the image sequence to use. + dtype: Tensorflow data type; Data type of the image. + zero_centering: If True, frames are normalized to values in [-1, 1]. + If False, values in [0, 1]. + + Returns: + dict; Example that has an 'inputs'. + """ + frames = example[camera_name] + frames = processors.normalize_image(frames, zero_centering, dtype) + return {'inputs': frames} + + +def augment_train_example(example, num_frames=30, stride=1): + """Augment the given video for training. + + Args: + example: dict; Example that has an 'inputs'. + num_frames: Number of frames per subclip. + stride: Temporal stride to sample frames. + + Returns: + dict; Example that has an 'inputs'. + """ + frames = example['inputs'] + frames = processors.sample_sequence(frames, num_frames, True, stride) + frames = processors.random_flip_left_right(frames) + return {'inputs': frames} + + +def preprocess_eval_example(example, + camera_name='image_main', + dtype=tf.float32, + num_frames=30, + stride=1, + num_clips=1, + zero_centering=True): + """Preprocesses the given video for evaluation. + + Args: + example: dict; Example that has an 'inputs'. + camera_name: Name of the image sequence to use. + dtype: Tensorflow data type; Data type of the image. + num_frames: Number of frames per subclip. + stride: Temporal stride to sample frames. + num_clips: Linearly spaced clips to sample from each example. + zero_centering: If True, frames are normalized to values in [-1, 1]. + If False, values in [0, 1]. + + Returns: + dict; Example that has an 'inputs'. + """ + frames = example[camera_name] + frames = processors.normalize_image(frames, zero_centering, dtype) + clips = processors.sample_linspace_sequence(frames, num_clips, num_frames, + stride) + return {'inputs': clips} + + +def postprocess_eval_batch(batch, num_frames=30): + """Postprocesses the given batch for evaluation. + + Reshapes the batch from [bs, num_clips * num_frames, ...] into + [bs * num_clips, num_frames, ...]. + + Args: + batch: dict; Batch that has an 'inputs'. + num_frames: Number of frames per subclip. + Returns: + dict; Example that has an 'inputs'. + """ + batch_clips = batch['inputs'] + batch_clips = tf.reshape(batch_clips, + (-1, num_frames, *batch_clips.shape[2:])) + return {'inputs': batch_clips} + + +@datasets.add_dataset('bair') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the BAIR train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + dtype = getattr(tf, dtype_str) + dataset_configs = dataset_configs or {} + camera_name = dataset_configs.get('camera_name', 'image_main') + num_frames = dataset_configs.get('num_frames', 30) + stride = dataset_configs.get('stride', 1) + zero_centering = dataset_configs.get('zero_centering', True) + num_eval_clips = dataset_configs.get('num_eval_clips', 1) + shuffle_buffer_size = dataset_configs.get('shuffle_buffer_size', None) + preprocess_train = functools.partial( + preprocess_train_example, + camera_name=camera_name, + dtype=dtype, + zero_centering=zero_centering) + augment_train = functools.partial( + augment_train_example, num_frames=num_frames, stride=stride) + preprocess_eval = functools.partial( + preprocess_eval_example, + camera_name=camera_name, + dtype=dtype, + num_frames=num_frames, + stride=stride, + num_clips=num_eval_clips, + zero_centering=zero_centering) + if num_eval_clips > 1: + postprocess_eval = functools.partial( + postprocess_eval_batch, num_frames=num_frames) + else: + postprocess_eval = None + + logging.info('Loading train split of the BAIR dataset.') + train_ds, _ = dataset_utils.load_split_from_tfds( + 'bair_robot_pushing_small', + batch_size, + split='train', + preprocess_example=preprocess_train, + augment_train_example=augment_train, + shuffle_buffer_size=shuffle_buffer_size, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the BAIR dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'bair_robot_pushing_small', + eval_batch_size, + split='test', + preprocess_example=preprocess_eval, + postprocess_batch=postprocess_eval) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, + train=False, + batch_size=eval_batch_size * num_eval_clips) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + if dataset_configs.get('prefetch_to_device'): + # Async bind batch to device which speeds up training. + train_iter = jax_utils.prefetch_to_device( + train_iter, dataset_configs.get('prefetch_to_device')) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, num_frames, 64, 64, 3) + num_train_examples = dataset_utils.get_num_examples( + 'bair_robot_pushing_small', 'train') + num_eval_examples = dataset_utils.get_num_examples('bair_robot_pushing_small', + 'test') * num_eval_clips + meta_data = { + 'num_classes': None, + 'input_shape': input_shape, + 'num_train_examples': num_train_examples, + 'num_eval_examples': num_eval_examples, + 'input_dtype': getattr(jnp, dtype_str), + 'target_is_onehot': False, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/big_transfer/README.md b/scenic/dataset_lib/big_transfer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/big_transfer/__init__.py b/scenic/dataset_lib/big_transfer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/big_transfer/bit.py b/scenic/dataset_lib/big_transfer/bit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/big_transfer/builder.py b/scenic/dataset_lib/big_transfer/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/big_transfer/registry.py b/scenic/dataset_lib/big_transfer/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/cifar10_dataset.py b/scenic/dataset_lib/cifar10_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d592b33b6e33e328a59d22011c8b151f01741ff1 --- /dev/null +++ b/scenic/dataset_lib/cifar10_dataset.py @@ -0,0 +1,180 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the CIFAR10 dataset.""" + +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + +HEIGHT = 32 +WIDTH = 32 +NUM_CHANNELS = 3 + +# Computed from the training set by taking the per-channel mean/std-dev +# over sample, height and width axes of all training samples. +MEAN_RGB = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255] +STDDEV_RGB = [0.2470 * 255, 0.2435 * 255, 0.2616 * 255] + + +def preprocess_example(example, dtype=tf.float32): + """Preprocesses the given example. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + + Returns: + A preprocessed example. + """ + image = tf.cast(example['image'], dtype=dtype) + if dtype not in [tf.int32, tf.int64, tf.uint32, tf.uint64]: + mean_rgb = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=dtype) + std_rgb = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=dtype) + image = (image - mean_rgb) / std_rgb + return {'inputs': image, 'label': example['label']} + + +def augment_example(example, dtype=tf.float32, data_augmentations=None): + """Apply data augmentation on the given training example. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + data_augmentations: list(str); Types of data augmentation applied on + training data. + + Returns: + An augmented training example. + """ + image = tf.cast(example['inputs'], dtype=dtype) + if data_augmentations is not None: + if 'cifar_default' in data_augmentations: + image = dataset_utils.augment_random_crop_flip( + image, HEIGHT, WIDTH, NUM_CHANNELS, crop_padding=4, flip=True) + image = tf.cast(image, dtype=dtype) + return {'inputs': image, 'label': example['label']} + + +@datasets.add_dataset('cifar10') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the CIFAR10 train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + dataset_configs = dataset_configs or {} + data_augmentations = dataset_configs.get('data_augmentations', []) + # alwayse include the default data augmentation + data_augmentations.append('cifar_default') + for da in data_augmentations: + if da not in ['mixup', 'cifar_default']: + raise ValueError(f'Data augmentation type {da} is not yet supported ' + f'in the CIFAR dataset.') + + dtype = getattr(tf, dtype_str) + target_is_onehot = False + preprocess_ex = functools.partial(preprocess_example, dtype=dtype) + + logging.info('Loading train split of the CIFAR10 dataset.') + augment_ex = functools.partial( + augment_example, dtype=dtype, data_augmentations=data_augmentations) + train_ds, train_ds_info = dataset_utils.load_split_from_tfds( + 'cifar10', + batch_size, + split='train', + preprocess_example=preprocess_ex, + augment_train_example=augment_ex, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the CIFAR10 dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'cifar10', + eval_batch_size, + split='test', + preprocess_example=preprocess_ex) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size) + num_classes = train_ds_info.features['label'].num_classes + target_to_one_hot_batches = functools.partial( + dataset_utils.target_to_one_hot, num_classes=num_classes) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + mixup_batches = functools.partial(dataset_utils.mixup, alpha=1.0) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + if 'mixup' in data_augmentations: + train_iter = map(target_to_one_hot_batches, train_iter) + train_iter = map(mixup_batches, train_iter) + target_is_onehot = True + train_iter = map(shard_batches, train_iter) + + # Note: samples will be dropped if the number of test samples + # (EVAL_IMAGES=10000) is not divisible by the evaluation batch size + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + if target_is_onehot: + eval_iter = map(target_to_one_hot_batches, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, HEIGHT, WIDTH, NUM_CHANNELS) + meta_data = { + 'num_classes': num_classes, + 'input_shape': input_shape, + 'num_train_examples': dataset_utils.get_num_examples('cifar10', 'train'), + 'num_eval_examples': dataset_utils.get_num_examples('cifar10', 'test'), + 'input_dtype': getattr(jnp, dtype_str), + 'target_is_onehot': target_is_onehot, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/cityscapes_dataset.py b/scenic/dataset_lib/cityscapes_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a945db0850fc81f0d0730aff092494b75f2a996f --- /dev/null +++ b/scenic/dataset_lib/cityscapes_dataset.py @@ -0,0 +1,363 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the Cityscapes dataset.""" + +import collections +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +import numpy as np +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + +# Based on https://github.com/mcordts/cityscapesScripts +CityscapesClass = collections.namedtuple( + 'CityscapesClass', + ['name', 'id', 'train_id', 'category', 'category_id', 'has_instances', + 'ignore_in_eval', 'color']) + +CLASSES = [ + CityscapesClass( + 'unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass( + 'ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass( + 'rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass( + 'out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass( + 'static', 4, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass( + 'dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), + CityscapesClass( + 'ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), + CityscapesClass( + 'road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), + CityscapesClass( + 'sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), + CityscapesClass( + 'parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), + CityscapesClass( + 'rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), + CityscapesClass( + 'building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), + CityscapesClass( + 'wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), + CityscapesClass( + 'fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), + CityscapesClass( + 'guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), + CityscapesClass( + 'bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), + CityscapesClass( + 'tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), + CityscapesClass( + 'pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), + CityscapesClass( + 'polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), + CityscapesClass( + 'traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), + CityscapesClass( + 'traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), + CityscapesClass( + 'vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), + CityscapesClass( + 'terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), + CityscapesClass( + 'sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), + CityscapesClass( + 'person', 24, 11, 'human', 6, True, False, (220, 20, 60)), + CityscapesClass( + 'rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), + CityscapesClass( + 'car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), + CityscapesClass( + 'truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), + CityscapesClass( + 'bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), + CityscapesClass( + 'caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), + CityscapesClass( + 'trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), + CityscapesClass( + 'train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), + CityscapesClass( + 'motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), + CityscapesClass( + 'bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), + CityscapesClass( + 'license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), +] + +# Number of pixels per Cityscapes class ID in the training set: +PIXELS_PER_CID = { + 7: 3806423808, + 8: 629490880, + 11: 2354443008, + 12: 67089092, + 13: 91210616, + 17: 126753000, + 19: 21555918, + 20: 57031712, + 21: 1647446144, + 22: 119165328, + 23: 415038624, + 24: 126403824, + 25: 13856368, + 26: 725164864, + 27: 27588982, + 28: 24276994, + 31: 24195352, + 32: 10207740, + 33: 42616088 +} + + +def preprocess_example(example, train, dtype=tf.float32, resize=None): + """Preprocesses the given image. + + Args: + example: dict; Example coming from TFDS. + train: bool; Whether to apply training-specific preprocessing or not. + dtype: Tensorflow data type; Data type of the image. + resize: sequence; [H, W] to which image and labels should be resized. + + Returns: + An example dict as required by the model. + """ + image = dataset_utils.normalize(example['image_left'], dtype) + mask = example['segmentation_label'] + + # Resize test images (train images are cropped/resized during augmentation): + if not train: + if resize is not None: + image = tf.image.resize(image, resize, 'bilinear') + mask = tf.image.resize(mask, resize, 'nearest') + + image = tf.cast(image, dtype) + mask = tf.cast(mask, dtype) + mask = tf.squeeze(mask, axis=2) + return {'inputs': image, 'label': mask} + + +def augment_example( + example, dtype=tf.float32, resize=None, **inception_crop_kws): + """Augments the given train image. + + Args: + example: dict; Example coming from TFDS. + dtype: Tensorflow data type; Data type of the image. + resize: sequence; [H, W] to which image and labels should be resized. + **inception_crop_kws: Keyword arguments passed on to + inception_crop_with_mask. + + Returns: + An example dict as required by the model. + """ + image = example['inputs'] + mask = example['label'][..., tf.newaxis] + + # Random crop and resize ("Inception crop"): + image, mask = dataset_utils.inception_crop_with_mask( + image, + mask, + resize_size=image.shape[-3:-1] if resize is None else resize, + **inception_crop_kws) + + # Random LR flip: + seed = tf.random.uniform(shape=[2], maxval=2**31 - 1, dtype=tf.int32) + image = tf.image.stateless_random_flip_left_right(image, seed) + mask = tf.image.stateless_random_flip_left_right(mask, seed) + + image = tf.cast(image, dtype) + mask = tf.cast(mask, dtype) + mask = tf.squeeze(mask, axis=2) + return {'inputs': image, 'label': mask} + + +def get_post_exclusion_labels(): + """Determines new labels after excluding bad classes. + + See Figure 1 in https://arxiv.org/abs/1604.01685 for which classes are + excluded. Excluded classes get the new label -1. + + Returns: + An array of length num_old_classes, containing new labels. + """ + old_to_new_labels = np.array( + [-1 if c.ignore_in_eval else c.train_id for c in CLASSES]) + assert np.all(np.diff([i for i in old_to_new_labels if i >= 0]) == 1) + return old_to_new_labels + + +def get_class_colors(): + """Returns a [num_classes, 3] array of colors for the model output labels.""" + cm = np.stack([c.color for c in CLASSES if not c.ignore_in_eval], axis=0) + return cm / 255.0 + + +def get_class_names(): + """Returns a list with the class names of the model output labels.""" + return [c.name for c in CLASSES if not c.ignore_in_eval] + + +def get_class_proportions(): + """Returns a [num_classes] array of pixel frequency proportions.""" + p = [PIXELS_PER_CID[c.id] for c in CLASSES if not c.ignore_in_eval] + return np.array(p) / np.sum(p) + + +def exclude_bad_classes(batch, new_labels): + """Adjusts masks and batch_masks to exclude void and rare classes. + + This must be applied after dataset_utils.maybe_pad_batch() because we also + update the batch_mask. Note that the data is already converted to Numpy by + then. + + Args: + batch: dict; Batch of data examples. + new_labels: nd-array; array of length num_old_classes, containing new + labels. + + Returns: + Updated batch dict. + """ + # Convert old labels to new labels: + batch['label'] = new_labels[batch['label'].astype(np.int32)] + + # Set batch_mask to 0 at pixels that have an excluded label: + mask_dtype = batch['batch_mask'].dtype + batch['batch_mask'] = ( + batch['batch_mask'].astype(np.bool_) & (batch['label'] != -1)) + batch['batch_mask'] = batch['batch_mask'].astype(mask_dtype) + + return batch + + +@datasets.add_dataset('cityscapes') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the Cityscapes train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + dtype = getattr(tf, dtype_str) + dataset_configs = dataset_configs or {} + target_size = dataset_configs.get('target_size', None) + + logging.info('Loading train split of the Cityscapes dataset.') + preprocess_ex_train = functools.partial( + preprocess_example, train=True, dtype=dtype, resize=None) + augment_ex = functools.partial( + augment_example, dtype=dtype, resize=target_size, area_min=30, + area_max=100) + + train_split = dataset_configs.get('train_split', 'train') + train_ds, _ = dataset_utils.load_split_from_tfds( + 'cityscapes', + batch_size, + split=train_split, + preprocess_example=preprocess_ex_train, + augment_train_example=augment_ex, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading validation split of the Cityscapes dataset.') + preprocess_ex_eval = functools.partial( + preprocess_example, train=False, dtype=dtype, resize=target_size) + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'cityscapes', eval_batch_size, split='validation', + preprocess_example=preprocess_ex_eval) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size, + pixel_level=True) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size, + pixel_level=True) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + exclude_classes = functools.partial( + exclude_bad_classes, new_labels=get_post_exclusion_labels()) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(exclude_classes, train_iter) + train_iter = map(shard_batches, train_iter) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(exclude_classes, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + if target_size is None: + input_shape = (-1, 1024, 2048, 3) + else: + input_shape = (-1,) + tuple(target_size) + (3,) + + meta_data = { + 'num_classes': + len([c.id for c in CLASSES if not c.ignore_in_eval]), + 'input_shape': + input_shape, + 'num_train_examples': + dataset_utils.get_num_examples('cityscapes', train_split), + 'num_eval_examples': + dataset_utils.get_num_examples('cityscapes', 'validation'), + 'input_dtype': + getattr(jnp, dtype_str), + 'target_is_onehot': + False, + 'class_names': + get_class_names(), + 'class_colors': + get_class_colors(), + 'class_proportions': + get_class_proportions(), + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/coco_dataset/__init__.py b/scenic/dataset_lib/coco_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/coco_dataset/coco_eval.py b/scenic/dataset_lib/coco_dataset/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/coco_dataset/coco_utils.py b/scenic/dataset_lib/coco_dataset/coco_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/dataset_utils.py b/scenic/dataset_lib/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14d94b2ce547deaa1a254de67ddb4a49d901f59c --- /dev/null +++ b/scenic/dataset_lib/dataset_utils.py @@ -0,0 +1,779 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utils for used by different dataset builders. + +Many of these were originally implemented by: Lucas Beyer, Alex Kolesnikov, +Xiaohua Zhai and other collaborators from Brain ZRH. +""" + +import collections +import dataclasses +import functools +import itertools +from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union + +from absl import logging +from flax.training import common_utils +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +PyTree = Any +DatasetIterator = Union[Iterator[Any], Dict[str, Iterator[Any]]] +DatasetIteratorProvider = Callable[[], DatasetIterator] + + +@dataclasses.dataclass(frozen=True) +class Dataset: + """Dataset type. + + Each instance of the Dataset has three iterators, train_iter, valid_iter, + and test_iter, that yield a batch, where each batch is a (nested) dict of + numpy arrays. These iterators are created by normally applying these + functions on TFDS instances: + + - dataset_utils.tf_to_numpy -> convert tensors to numpy arrays. + - dataset_utils.maybe_pad_batch -> pad partial batches and create + batch_mask if needed. + - dataset_utils.shard_batches -> shard batch across devices by reshaping + `[bs, ...]` to `[num_local_devices, bs/(num_local_devices), ...]`. + + Beside these iterators, there is a dictionary that stores the metadata + information about the dataset, that can be used for different purposes. + For instance, these fields are used in most of the datasets: + + 'input_shape': Used during compiling and initializing the model. + 'num_train_examples': Used for computing the number of training steps + and controlling the train_iter. + 'num_eval_examples': Same as num_train_examples, but for valid_iter. + 'num_test_examples': Same as num_train_examples, but for test_iter. + 'target_is_onehot': Used in the loss and metric functions. + + Note that each dataset can define its own meta-data field that is used + in the model and/or the trainer, depending on the task. As an example, for + classification tasks, `num_classes` is used for the configuring head of + the model. + """ + train_iter: DatasetIterator | DatasetIteratorProvider | None = None + valid_iter: DatasetIterator | DatasetIteratorProvider | None = None + test_iter: DatasetIterator | DatasetIteratorProvider | None = None + meta_data: Dict[str, Any] = dataclasses.field(default_factory=dict) + + train_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None + valid_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None + test_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None + + +def maybe_pad_batch(batch: Dict[str, PyTree], + train: bool, + batch_size: int, + pixel_level: bool = False, + inputs_key: str = 'inputs', + batch_dim: int = 0) -> Dict[str, jnp.ndarray]: + """Zero pad the batch on the right to the batch_size. + + All leave tensors in the batch pytree will be padded. This function expects + the root structure of the batch pytree to be a dictionary and returns a + dictionary with the same structure (and substructures), additionally with the + key 'batch_mask' added to the root dict, with 1.0 indicating indices which are + true data and 0.0 indicating a padded index. `batch_mask` will be used for + calculating the weighted cross entropy, or weighted accuracy. + + Note that in this codebase, we assume we drop the last partial batch from the + training set, so if the batch is from the training set (i.e. `train=True`), + or when the batch is from the test/validation set, but it is a complete batch, + we *modify* the batch dict by adding an array of ones as the `batch_mask` of + all examples in the batch. Otherwise, we create a new dict that has the padded + patch and its corresponding `batch_mask` array. + + Note that batch_mask can be also used as the label mask (not input mask), for + task that are pixel/token level. This is simply done by applying the mask we + make for padding the partial batches on top of the existing label mask. + + Args: + batch: A dictionary containing a pytree. If `inputs_key` is not set, we use + the first leave to get the current batch size. Otherwise, the tensor + mapped with `inputs_key` at the root dictionary is used. + train: if the batch is from the training data. In that case, we drop + the last (incomplete) batch and thus don't do any padding. + batch_size: All arrays in the dict will be padded to have first + dimension equal to desired_batch_size. + pixel_level: If True, this will create a pixel-level (instead of + example-level) mask, e.g. for segmentation models. + inputs_key: Indicating the key used for the input that we do batch padding + based on. + batch_dim: Batch dimension. The default is 0, but it can be different + if a sharded batch is given. + + Returns: + A dictionary mapping the same keys to the padded batches. Additionally, we + add a key representing weights, to indicate how the batch was padded. + """ + assert batch_dim >= 0, f'batch_dim=={batch_dim} is expected to be >= 0' + if inputs_key is None: + sample_tensor = jax.tree_util.tree_leaves(batch)[0] + else: + sample_tensor = batch[inputs_key] + if sample_tensor.shape[batch_dim] > batch_size: + raise ValueError( + f'The indicated target batch_size is {batch_size}, but ' + 'the size of the current batch is larger than that: ' + f'{sample_tensor.shape[batch_dim]}.' + ) + batch_pad = batch_size - sample_tensor.shape[batch_dim] + + if pixel_level: + unpadded_mask_shape = sample_tensor.shape[:-1] + else: + assert 'batch_mask' not in batch, ( + 'When the labels of the task are not pixel-level, batch_mask should ' + 'not be already present in the batch.') + unpadded_mask_shape = sample_tensor.shape[:batch_dim + 1] + + if train and batch_pad != 0: + raise ValueError('In this codebase, we assumed that we always drop the ' + 'last partial batch of the train set. Please use ' + '` drop_remainder=True` for the training set.') + # Most batches will not need padding, so we quickly return to avoid slowdown. + if train or batch_pad == 0: + if 'batch_mask' not in batch: + batch['batch_mask'] = np.ones(unpadded_mask_shape, dtype=np.float32) + return batch + + def zero_pad(array): + pad_with = ([(0, 0)] * batch_dim + [(0, batch_pad)] + + [(0, 0)] * (array.ndim - batch_dim - 1)) + return np.pad(array, pad_with, mode='constant') + + padded_batch = jax.tree_util.tree_map(zero_pad, batch) + padded_batch_mask = zero_pad(np.ones(unpadded_mask_shape, dtype=np.float32)) + if 'batch_mask' in padded_batch: + padded_batch['batch_mask'] *= padded_batch_mask + else: + padded_batch['batch_mask'] = padded_batch_mask + return padded_batch + + +def shard(pytree, n_devices=None): + """Reshapes all arrays in the pytree to add a leading n_devices dimension. + + To be used for pmap-based data-parallelism. + + Note: We assume that all arrays in the pytree have leading dimension divisible + by n_devices and reshape (host_batch_size, height, width, channel) to + (local_devices, device_batch_size, height, width, channel). + + Args: + pytree: A pytree of arrays to be sharded. + n_devices: If None, this will be set to jax.local_device_count(). + + Returns: + Sharded data. + """ + if n_devices is None: + n_devices = jax.local_device_count() + + def _shard_array(array): + return array.reshape((n_devices, -1) + array.shape[1:]) + + return jax.tree_util.tree_map(_shard_array, pytree) + + +def shard_jit( + data: PyTree, + global_devices: np.ndarray, + mesh_axis: tuple[str, ...] = ('devices',), +) -> PyTree: + """Shards data for use in jit-based pipelines. + + Note that the order of global devices for sharding data is important and + should be compatible with device order used in the rest of the trainer for + models params, state, etc. + + Based on: + https://github.com/google-research/big_vision/blob/main/big_vision/input_pipeline.py. + + Args: + data: PyTree of data + global_devices: List of global devices to shard over. + mesh_axis: Specifies axis separately. + + Returns: + Sharded data. + """ + + def _shard_array(x): + mesh = jax.sharding.Mesh(global_devices, mesh_axis) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(mesh_axis) + ) + local_ds = mesh.local_devices + + x = np.asarray(memoryview(x)) # No-copy: http://shortn/_KM5whIEtWI + xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds) + + global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:]) + return jax.make_array_from_single_device_arrays(global_shape, sharding, xs) + + return jax.tree_util.tree_map(_shard_array, data) + + +def prefetch_iterator(it, n): + """Prefetches batches from an iterator. + + Runs iterator `it` ahead for `n` steps. + + Adapted from big_vision: + https://github.com/google-research/big_vision/blob/main/big_vision/input_pipeline.py. + + Args: + it: Iterator + n: Number of steps to prefect for. + + Yields: + Original items from the iterator which have been prefetched. + """ + if not n: + yield from it + return + queue = collections.deque() + + def enqueue(n_steps): # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(it, n_steps): + queue.append(data) + + enqueue(n) # Fill up the buffer. + while queue: + yield queue.popleft() + enqueue(1) + + +def unshard(pytree): + """Reshapes all arrays in the pytree from [ndev, bs, ...] to [host_bs, ...]. + + Args: + pytree: A pytree of arrays to be sharded. + + Returns: + Sharded data. + """ + + def _unshard_array(array): + ndev, bs = array.shape[:2] + return array.reshape((ndev * bs,) + array.shape[2:]) + + return jax.tree_util.tree_map(_unshard_array, pytree) + + +def tf_to_numpy(batch): + """Convert an input batch from tf Tensors to numpy arrays. + + Args: + batch: dict; A dictionary that has items in a batch: image and labels. + + Returns: + Numpy arrays of the given tf Tensors. + """ + # Use _numpy() for zero-copy conversion between TF and NumPy. + convert_data = lambda x: x._numpy() # pylint: disable=protected-access + return jax.tree_util.tree_map(convert_data, batch) + + +def augment_random_crop_flip(image, + height=None, + width=None, + num_channels=None, + crop_padding=4, + flip=True): + """Augment small image with random crop and h-flip. + + Args: + image: Input image to augment. + height: int; Height of the target image. + width: int; Width of the target image. + num_channels: int; Number of channels of the target image. + crop_padding: int; Random crop range. + flip: bool; If True perform random horizontal flip. + + Returns: + Augmented image. + """ + h, w, c = image.get_shape().as_list() + height = height or h + width = width or w + num_channels = num_channels or c + + assert crop_padding >= 0 + if crop_padding > 0: + # Pad with reflection padding + # (See https://arxiv.org/abs/1605.07146) + # Section 3. + image = tf.pad(image, [[crop_padding, crop_padding], + [crop_padding, crop_padding], [0, 0]], 'REFLECT') + + # Randomly crop a [HEIGHT, WIDTH] section of the image. + image = tf.image.random_crop(image, [height, width, num_channels]) + + if flip: + # Randomly flip the image horizontally. + image = tf.image.random_flip_left_right(image) + + return image + + +def normalize(image, dtype=tf.float32): + """Normalizes the value of pixels in the given image. + + Args: + image: `Tensor` representing an image binary of arbitrary size. + dtype: Tensorflow data type, Data type of the image. + + Returns: + A normalized image `Tensor`. + """ + image = tf.cast(image, dtype=dtype) + if dtype not in [tf.int32, tf.int64, tf.uint32, tf.uint64]: + image /= tf.constant(255.0, shape=[1, 1, 1], dtype=dtype) + return image + + +def load_split_from_tfds(dataset_name, + batch_size, + split, + data_dir=None, + preprocess_example=None, + augment_train_example=None, + postprocess_batch=None, + shuffle_buffer_size=None, + shuffle_seed=0, + cache=True, + **kwargs): + """Loads a split from a dataset using TensorFlow Datasets. + + Args: + dataset_name: str; Name of the dataset to be used to load from tfds. + batch_size: int; The batch size returned by the data pipeline. + split: str; Name of the split to be loaded. + data_dir: str; Data directory. + preprocess_example: function; A function that given an example, returns the + preprocessed example. Note that the preprocessing is done BEFORE caching + to re-use them. + augment_train_example: A function that given a train example returns the + augmented example. Note that this function is applied AFTER caching and + repeat to get true randomness. + postprocess_batch: function; A function that given a batch, returns the + postprocessed batch. + shuffle_buffer_size: int; Size of the tf.data.dataset shuffle buffer. + shuffle_seed: int; Seed for shuffling the training data. + cache: bool; Whether to cache the dataset in memory. + **kwargs: Passed to tfds.builder(). + + Returns: + A `tf.data.Dataset`, and dataset information. + """ + return load_split_from_tfds_builder( + builder=tfds.builder(dataset_name, data_dir=data_dir, **kwargs), + batch_size=batch_size, + split=split, + preprocess_example=preprocess_example, + augment_train_example=augment_train_example, + postprocess_batch=postprocess_batch, + shuffle_buffer_size=shuffle_buffer_size, + shuffle_seed=shuffle_seed, + cache=cache) + + +def load_split_from_tfds_builder(builder, + batch_size, + split, + preprocess_example=None, + augment_train_example=None, + postprocess_batch=None, + shuffle_buffer_size=None, + shuffle_seed=0, + cache=True): + """Loads a split from a dataset using TensorFlow Datasets compatible builder. + + Args: + builder: tfds.core.DatasetBuilder; A TFDS compatible dataset builder. + batch_size: int; The batch size returned by the data pipeline. + split: str; Name of the split to be loaded. + preprocess_example: function; A function that given an example, returns the + preprocessed example. Note that the preprocessing is done BEFORE caching + to re-use them. + augment_train_example: A function that given a train example returns the + augmented example. Note that this function is applied AFTER caching and + repeat to get true randomness. + postprocess_batch: function; A function that given a batch, returns the + postprocessed batch. + shuffle_buffer_size: int; Size of the tf.data.dataset shuffle buffer. + shuffle_seed: int; Seed for shuffling the training data. + cache: bool; Whether to cache dataset in memory. + + Returns: + A `tf.data.Dataset`, and dataset information. + """ + # Prepare map functions. + preprocess_example = preprocess_example or (lambda ex: ex) + augment_train_example = augment_train_example or (lambda ex: ex) + postprocess_batch = postprocess_batch or (lambda ex: ex) + shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size) + + # Download dataset: + builder.download_and_prepare() + + # Each host is responsible for a fixed subset of data. + data_range = tfds.even_splits(split, jax.process_count())[jax.process_index()] + ds = builder.as_dataset(split=data_range, shuffle_files=False) + options = tf.data.Options() + options.threading.private_threadpool_size = 48 + ds = ds.with_options(options) + + # Applying preprocessing before `ds.cache()` to re-use it. + ds = ds.map( + preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + # Caching. + if cache: + ds = ds.cache() + + if 'train' in split: + # First repeat then batch. + ds = ds.repeat() + # Augmentation should be done after repeat for true randomness. + ds = ds.map( + augment_train_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + # Shuffle after augmentation to avoid loading uncropped images into buffer: + ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=True) + + else: + # First batch then repeat. + ds = ds.batch(batch_size, drop_remainder=False) + ds = ds.repeat() + + ds = ds.map( + postprocess_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + return ds, builder.info + + +def target_to_one_hot(batch, num_classes): + """Converts the labels to one-hot targets. + + Args: + batch: dict; A batch of data with 'inputs' and 'label'. + num_classes: int; Number of classes. + + Returns: + Batch with one-hot labels. + """ + return { + 'inputs': batch['inputs'], + 'label': common_utils.onehot(batch['label'], num_classes) + } + + +def mixup(batch: Dict['str', jnp.ndarray], + alpha: float = 1.0, + image_format: str = 'NHWC', + input_key: str = 'inputs', + label_key: str = 'label', + rng: Optional[Any] = None) -> Dict['str', jnp.ndarray]: + """Mixes images and labels within a single batch. + + For more details, please see https://arxiv.org/abs/1710.09412. + + This function supports both using `numpy` to do mixup in the input-pipeline + and `jax.numpy` to do mixup within a jitted/pmapped function (e.g. within + a pmapped train step to apply mixup on device patch). + + Results in a batch with: + mixed_images[idx] = weight * images[idx] + (1-weight) * images[-(idx+1)], + where weight is sampled from a beta distribution with parameter alpha. + + Args: + batch: dict; A batch of data with 'inputs' and 'label'. + alpha: float; Used to control the beta distribution that weight is sampled + from. + image_format: string; The format of the input images. + input_key: The key in the `batch` dictionary corresponding to the input + images. Default is `inputs`. + label_key: The key in the `batch` dictionary corresponding to the labels. + Default is `labels`. + rng: JAX rng key. If given, JAX numpy will be used as the backend, and if + None (default value), normal numpy will be used. + + Returns: + Tuple (mixed_images, mixed_labels). + """ + images, labels = batch[input_key], batch[label_key] + if labels.shape[-1] == 1: + raise ValueError('Mixup requires one-hot targets.') + if 'N' not in image_format: + raise ValueError('Mixup requires "N" to be in "image_format".') + + batch_size = labels.shape[0] + + # Set up the numpy backend and prepare mixup weights. + if rng is None: + np_backend = np # Ordinary numpy + weight = np_backend.random.beta(alpha, alpha) + else: + np_backend = jnp # JAX numpy + weight = jax.random.beta(rng, alpha, alpha) + label_weight_shape = np.ones(labels.ndim) + label_weight_shape[image_format.index('N')] = batch_size + weight *= np_backend.ones(label_weight_shape.astype(np_backend.int32)) + + # Mixup labels. + batch[label_key] = weight * labels + (1.0 - weight) * labels[::-1] + + # Mixup inputs. + # Shape calculations use np to avoid device memory fragmentation: + image_weight_shape = np.ones((images.ndim)) + image_weight_shape[image_format.index('N')] = batch_size + weight = np_backend.reshape(weight, + image_weight_shape.astype(np_backend.int32)) + reverse = tuple( + slice(images.shape[i]) if d != 'N' else slice(-1, None, -1) + for i, d in enumerate(image_format)) + batch[input_key] = weight * images + (1.0 - weight) * images[reverse] + + return batch + + +@functools.lru_cache(maxsize=None) +def get_builder(dataset, data_dir): + return tfds.builder(dataset, data_dir=data_dir, try_gcs=True) + + +def get_num_examples(dataset, split, data_dir=None): + """Returns the total number of examples in a dataset split.""" + builder = get_builder(dataset, data_dir) + # Download dataset: + builder.download_and_prepare() + num_examples = builder.info.splits[split].num_examples + remainder = num_examples % jax.process_count() + if remainder: + warning = ( + f'Dropping {remainder} examples for the ' + f'{builder.info.name} dataset, {split} split. ' + 'The reason is that all hosts should have the same number ' + 'of examples in order to guarantee that they stay in sync.' + ) + logging.warning(warning) + + return num_examples + + +def make_skip_decoders(skip_decode, features): + if skip_decode is None: + return None + elif isinstance(skip_decode, list) or isinstance(skip_decode, tuple): + return {f: tfds.decode.SkipDecoding() for f in skip_decode if f in features} + elif isinstance(skip_decode, dict): + return jax.tree_util.tree_map( + lambda _: tfds.decode.SkipDecoding(), skip_decode + ) + else: + raise ValueError( + 'skip_decode should be None, tuple, list, or dict - instead got' + f'{type(skip_decode)} {skip_decode}' + ) + + +def get_dataset_tfds( + dataset: str, + split: str, + shuffle_files: bool = True, + data_dir: Optional[str] = None, + skip_decode: Optional[Union[Sequence[str], Dict[Any, Any]]] = ('image',), +): + """Data provider.""" + builder = get_builder(dataset, data_dir) + split = tfds.even_splits(split, jax.process_count(), drop_remainder=True)[ + jax.process_index() + ] + skip_decoders = make_skip_decoders(skip_decode, builder.info.features) + # Each host is responsible for a fixed subset of data + return builder.as_dataset( + split=split, + shuffle_files=shuffle_files, + read_config=tfds.ReadConfig( + skip_prefetch=True, # We prefetch after pipeline. + try_autocache=False, # We control this, esp. for few-shot. + add_tfds_id=True, + ), + decoders=skip_decoders) + + +def make_pipeline(data, + preprocess_fn, + batch_size, + drop_remainder, + cache='loaded', + repeats=None, + repeat_after_batching=False, + shuffle_buffer_size=None, + prefetch=2, + ignore_errors=False, + dataset_service_address=None): + """Makes an input pipeline for `data`.""" + if cache not in ('loaded', 'batched', False, None): + raise ValueError(f'Unknown cache value {cache}') + + data = _add_tpu_host_options(data) + + if cache == 'loaded': + data = data.cache() + + if not repeat_after_batching: + data = data.repeat(repeats) + + if shuffle_buffer_size is not None: + data = data.shuffle(shuffle_buffer_size) + + data = data.map( + preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + if ignore_errors: + # Skip broken images. This does not slow things down. + data = data.apply(tf.data.experimental.ignore_errors()) + + data = data.batch(batch_size, drop_remainder=drop_remainder) + + if cache == 'batched': + data = data.cache() + + if repeat_after_batching: + data = data.repeat(repeats) + + if dataset_service_address: + data = distribute(data, dataset_service_address) + + if prefetch == 'autotune': + data = data.prefetch(tf.data.experimental.AUTOTUNE) + elif prefetch: + data = data.prefetch(prefetch) + # And 0 or None mean no prefetching. + + return data + + +def get_data(dataset, + split, + batch_size, + preprocess_fn=lambda x: x, + repeats=None, + shuffle_buffer_size=None, + prefetch=2, + cache='loaded', + repeat_after_batching=False, + drop_remainder=True, + data_dir=None, + ignore_errors=False, + shuffle_files=True, + dataset_service_address=None, + skip_decode=('image',)): + """API kept for backwards compatibility.""" + data = get_dataset_tfds( + dataset=dataset, + split=split, + shuffle_files=shuffle_files, + data_dir=data_dir, + skip_decode=skip_decode, + ) + if 'train' not in split: + dataset_service_address = None + return make_pipeline( + data=data, + preprocess_fn=preprocess_fn, + batch_size=batch_size, + drop_remainder=drop_remainder, + cache=cache, + repeats=repeats, + prefetch=prefetch, + shuffle_buffer_size=shuffle_buffer_size, + repeat_after_batching=repeat_after_batching, + ignore_errors=ignore_errors, + dataset_service_address=dataset_service_address) + + +def inception_crop_with_mask( + image, mask, resize_size=None, area_min=5, area_max=100): + """Applies the same inception-style crop to an image and a mask tensor. + + Inception-style crop is a random image crop (its size and aspect ratio are + random) that was used for training Inception models, see + https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf. + + Args: + image: [H, W, C] image tensor. + mask: [H, W, None] mask tensor. H and W must match the image. Will be + resized using tf.image.ResizeMethod.NEAREST_NEIGHBOR. + resize_size: Sequence of 2 ints; Resize image to [height, width] after crop. + area_min: minimal crop area. + area_max: maximal crop area. + + Returns: + Cropped image and mask tensors. + """ + begin, size, _ = tf.image.sample_distorted_bounding_box( + tf.shape(image), tf.zeros([0, 0, 4], tf.float32), + area_range=(area_min / 100, area_max / 100), + min_object_covered=0, # Don't enforce a minimum area. + use_image_if_no_bounding_boxes=True) + + # Process image: + image_cropped = tf.slice(image, begin, size) + image_cropped.set_shape([None, None, image.shape[-1]]) + if resize_size: + image_cropped = tf.image.resize( + image_cropped, resize_size, tf.image.ResizeMethod.BILINEAR) + + # Process mask: + mask_cropped = tf.slice(mask, begin, size) + mask_cropped.set_shape([None, None, mask.shape[-1]]) + if resize_size: + mask_cropped = tf.image.resize( + mask_cropped, resize_size, tf.image.ResizeMethod.NEAREST_NEIGHBOR) + + return image_cropped, mask_cropped + + +def distribute( + dataset: tf.data.Dataset, dataset_service_address: str, + processing_mode: str = 'parallel_epochs') -> tf.data.Dataset: + dataset_id = tf.data.experimental.service.register_dataset( + service=dataset_service_address, + dataset=dataset + ) + logging.info('tfds service: process %d got id %d', + jax.process_index(), dataset_id) + return tf.data.experimental.service.from_dataset_id( + processing_mode=processing_mode, + service=dataset_service_address, + dataset_id=dataset_id, + job_name='scenic_data_pipeline', + element_spec=dataset.element_spec) + + +def _add_tpu_host_options(data): + options = tf.data.Options() + options.threading.private_threadpool_size = 48 + options.threading.max_intra_op_parallelism = 1 + return data.with_options(options) diff --git a/scenic/dataset_lib/datasets.py b/scenic/dataset_lib/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..1882935bfb6d30957990f092ae4916c1889781fc --- /dev/null +++ b/scenic/dataset_lib/datasets.py @@ -0,0 +1,144 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for Scenic.""" + +import functools +import importlib +from typing import Callable, List + +from absl import logging +from scenic.dataset_lib import dataset_utils + +# The dict below hardcodes import that define datasets. This is necessary for +# several reasons: +# 1) Datasets are only registered once they are defined (have been imported). +# 2) We don't want the user code (e.g. trainers / projects) to have to import +# the dataset modules. Instead we'd like to do it for them. +# 3) And finally we don't want to import all datasets available to unless if the +# the user code does not need them. +# TODO(b/186631707): This routing table is not a great solution because it +# requires every new dataset to modify this import routing table. Going forward +# we should find a way to avoid that. +_IMPORT_TABLE = { + 'cifar10': 'scenic.dataset_lib.cifar10_dataset', + 'cityscapes': 'scenic.dataset_lib.cityscapes_dataset', + 'imagenet': 'scenic.dataset_lib.imagenet_dataset', + 'fashion_mnist': 'scenic.dataset_lib.fashion_mnist_dataset', + 'mnist': 'scenic.dataset_lib.mnist_dataset', + 'bair': 'scenic.dataset_lib.bair_dataset', + 'oxford_pets': 'scenic.dataset_lib.oxford_pets_dataset', + 'svhn': 'scenic.dataset_lib.svhn_dataset', + 'video_tfrecord_dataset': ( + 'scenic.projects.vivit.data.video_tfrecord_dataset' + ), + 'av_asr_tfrecord_dataset': ( + 'scenic.projects.avatar.datasets.av_asr_tfrecord_dataset' + ), + 'bit': 'scenic.dataset_lib.big_transfer.bit', + 'bert_wikibooks': ( + 'scenic.projects.baselines.bert.datasets.bert_wikibooks_dataset' + ), + 'bert_glue': 'scenic.projects.baselines.bert.datasets.bert_glue_dataset', + 'coco_detr_detection': ( + 'scenic.projects.baselines.detr.input_pipeline_detection' + ), + 'cityscapes_variants': ( + 'scenic.projects.robust_segvit.datasets.cityscapes_variants' + ), + 'robust_segvit_segmentation': ( + 'scenic.projects.robust_segvit.datasets.segmentation_datasets' + ), + 'robust_segvit_variants': ( + 'scenic.projects.robust_segvit.datasets.segmentation_variants' + ), + 'flexio': 'scenic.dataset_lib.flexio.flexio', +} + + +class DatasetRegistry(object): + """Static class for keeping track of available datasets.""" + _REGISTRY = {} + + @classmethod + def add(cls, name: str, builder_fn: Callable[..., dataset_utils.Dataset]): + """Add a dataset to the registry, i.e. register a dataset. + + Args: + name: Dataset name (must be unique). + builder_fn: Function to be called to construct the datasets. Must accept + dataset-specific arguments and return a dataset description. + + Raises: + KeyError: If the provided name is not unique. + """ + if name in cls._REGISTRY: + raise KeyError(f'Dataset with name ({name}) already registered.') + cls._REGISTRY[name] = builder_fn + + @classmethod + def get(cls, name: str) -> Callable[..., dataset_utils.Dataset]: + """Get a dataset from the registry by its name. + + Args: + name: Dataset name. + + Returns: + Dataset builder function that accepts dataset-specific parameters and + returns a dataset description. + + Raises: + KeyError: If the dataset is not found. + """ + if name not in cls._REGISTRY: + if name in _IMPORT_TABLE: + module = _IMPORT_TABLE[name] + importlib.import_module(module) + logging.info( + 'On-demand import of dataset (%s) from module (%s).', name, module) + if name not in cls._REGISTRY: + raise KeyError(f'Imported module ({module}) did not register dataset' + f'({name}). Please check that dataset names match.') + else: + raise KeyError(f'Unknown dataset ({name}). Did you import the dataset ' + f'module explicitly?') + return cls._REGISTRY[name] + + @classmethod + def list(cls) -> List[str]: + """List registered datasets.""" + return list(cls._REGISTRY.keys()) + + +def add_dataset(name: str, *args, **kwargs): + """Decorator for shorthand dataset registdation.""" + def inner(builder_fn: Callable[..., dataset_utils.Dataset] + ) -> Callable[..., dataset_utils.Dataset]: + DatasetRegistry.add(name, functools.partial(builder_fn, *args, **kwargs)) + return builder_fn + return inner + + +def get_dataset(dataset_name: str) -> Callable[..., dataset_utils.Dataset]: + """Maps dataset name to a dataset_builder. + + API kept for compatibility of existing code with the DatasetRegistry. + + Args: + dataset_name: Dataset name. + + Returns: + A dataset builder. + """ + return DatasetRegistry.get(dataset_name) diff --git a/scenic/dataset_lib/fashion_mnist_dataset.py b/scenic/dataset_lib/fashion_mnist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..04bce164ba27fbf61f7f9637ee39b9c9521bad44 --- /dev/null +++ b/scenic/dataset_lib/fashion_mnist_dataset.py @@ -0,0 +1,126 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the fashion-MNIST dataset.""" + +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + + +def preprocess_example(example, dtype=tf.float32): + """Preprocesses the given image. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + + Returns: + A preprocessed image `Tensor`. + """ + image = dataset_utils.normalize(example['image'], dtype) + return {'inputs': image, 'label': example['label']} + + +@datasets.add_dataset('fashion_mnist') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the fashion-MNIST train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + del dataset_configs + dtype = getattr(tf, dtype_str) + + preprocess_ex = functools.partial(preprocess_example, dtype=dtype) + logging.info('Loading train split of the Fashion-MNIST dataset.') + train_ds, train_ds_info = dataset_utils.load_split_from_tfds( + 'fashion_mnist', + batch_size, + split='train', + preprocess_example=preprocess_ex, + shuffle_seed=shuffle_seed) + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the Fashion-MNIST dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'fashion_mnist', + eval_batch_size, + split='test', + preprocess_example=preprocess_ex) + + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size) + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, 28, 28, 1) + meta_data = { + 'num_classes': + train_ds_info.features['label'].num_classes, + 'input_shape': + input_shape, + 'num_train_examples': + dataset_utils.get_num_examples('fashion_mnist', 'train'), + 'num_eval_examples': + dataset_utils.get_num_examples('fashion_mnist', 'test'), + 'input_dtype': + getattr(jnp, dtype_str), + 'target_is_onehot': + False, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/flexio/README.md b/scenic/dataset_lib/flexio/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/flexio/flexio.py b/scenic/dataset_lib/flexio/flexio.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/imagenet_dataset.py b/scenic/dataset_lib/imagenet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb76a459008914d20cf32a74bd5c3a1bd3eb78e1 --- /dev/null +++ b/scenic/dataset_lib/imagenet_dataset.py @@ -0,0 +1,376 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the ImageNet dataset.""" + +import functools +from typing import Optional + +from absl import logging +from flax import jax_utils +import jax +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf +import tensorflow_datasets as tfds + +TRAIN_IMAGES = 1281167 +EVAL_IMAGES = 50000 +NUM_CLASSES = 1000 + +IMAGE_SIZE = 224 +CROP_PADDING = 32 +MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] +STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: TF tensor; Binary image data. + bbox: `Tensor; Bounding boxes arranged `[1, num_boxes, coords]` where each + coordinate is [0, 1) and the coordinates are arranged as `[ymin, xmin, + ymax, xmax]`. If num_boxes is 0 then use the whole image. + min_object_covered: float; Defaults to `0.1`. The cropped area of the image + must contain at least this fraction of any bounding box supplied. + aspect_ratio_range: list[float]; The cropped area of the image must have an + aspect ratio = width / height within this range. + area_range: list[float]; The cropped area of the image must contain a + fraction of the supplied image within in this range. + max_attempts: int; Number of attempts at generating a cropped region of the + image of the specified constraints. After `max_attempts` failures, return + the entire image. + + Returns: + Cropped image TF Tensor. + """ + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _resize(image, image_size): + """Resizes the image. + + Args: + image: Tensor; Input image. + image_size: int; Image size. + + Returns: + Resized image. + """ + return tf.image.resize([image], [image_size, image_size], + method=tf.image.ResizeMethod.BICUBIC)[0] + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size): + """Make a random crop of `image_size`.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond(bad, lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: _resize(image, image_size)) + + return image + + +def _decode_and_center_crop(image_bytes, image_size): + """Crops to center of image with padding then scales `image_size`.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([ + offset_height, offset_width, padded_center_crop_size, + padded_center_crop_size + ]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = _resize(image, image_size) + + return image + + +def normalize_image(image): + image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) + return image + + +def preprocess_for_train(image_bytes, + dtype=tf.float32, + image_size=IMAGE_SIZE, + data_augmentations=None): + """Preprocesses the given image for training. + + Args: + image_bytes: Tensor; Representing an image binary of arbitrary size. + dtype: TF data type; Data type of the image. + image_size: int; The target size of the images. + data_augmentations: list(str); Types of data augmentation applied on + training data. + + Returns: + A preprocessed image `Tensor`. + """ + if data_augmentations is not None: + if 'default' in data_augmentations: + image = _decode_and_random_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.random_flip_left_right(image) + else: + image = _decode_and_center_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + + if dtype not in [tf.int32, tf.int64, tf.uint32, tf.uint64]: + image = normalize_image(image) + image = tf.image.convert_image_dtype(image, dtype=dtype) + else: + image = tf.cast(image, dtype=dtype) + return image + + +def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: Tensor; Representing an image binary of arbitrary size. + dtype: TF data type; Data type of the image. + image_size: int; The target size of the images. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + if dtype not in [tf.int32, tf.int64, tf.uint32, tf.uint64]: + image = normalize_image(image) + image = tf.image.convert_image_dtype(image, dtype=dtype) + else: + image = tf.cast(image, dtype=dtype) + return image + + +def imagenet_load_split(batch_size, + train, + onehot_labels, + dtype=tf.float32, + image_size=IMAGE_SIZE, + prefetch_buffer_size=10, + shuffle_seed=None, + data_augmentations=None): + """Creates a split from the ImageNet dataset using TensorFlow Datasets. + + For the training set, we drop the last partial batch. This is fine to do + because we additionally shuffle the data randomly each epoch, thus the trainer + will see all data in expectation. For the validation set, we pad the final + batch to the desired batch size. + + Args: + batch_size: int; The batch size returned by the data pipeline. + train: bool; Whether to load the train or evaluation split. + onehot_labels: Whether to transform the labels to one hot. + dtype: TF data type; Data type of the image. + image_size: int; The target size of the images. + prefetch_buffer_size: int; Buffer size for the TFDS prefetch. + shuffle_seed: The seed to use when shuffling the train split. + data_augmentations: list(str); Types of data augmentation applied on + training data. + + Returns: + A `tf.data.Dataset`. + """ + if train: + split_size = TRAIN_IMAGES // jax.process_count() + start = jax.process_index() * split_size + split = 'train[{}:{}]'.format(start, start + split_size) + else: + split_size = EVAL_IMAGES // jax.process_count() + start = jax.process_index() * split_size + split = 'validation[{}:{}]'.format(start, start + split_size) + + def decode_example(example): + if train: + image = preprocess_for_train(example['image'], dtype, image_size, + data_augmentations) + else: + image = preprocess_for_eval(example['image'], dtype, image_size) + + label = example['label'] + label = tf.one_hot(label, NUM_CLASSES) if onehot_labels else label + return {'inputs': image, 'label': label} + + dataset_builder = tfds.builder('imagenet2012:5.*.*') + # Download dataset: + dataset_builder.download_and_prepare() + ds = dataset_builder.as_dataset( + split=split, decoders={ + 'image': tfds.decode.SkipDecoding(), + }) + options = tf.data.Options() + options.threading.private_threadpool_size = 48 + ds = ds.with_options(options) + + ds = ds.cache() + + if train: + ds = ds.repeat() + ds = ds.shuffle(16 * batch_size, seed=shuffle_seed) + + # decode_example should be applied after caching as it also does augmentation + ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.batch(batch_size, drop_remainder=train) + + if not train: + ds = ds.repeat() + + ds = ds.prefetch(prefetch_buffer_size) + return ds + + +@datasets.add_dataset('imagenet') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + prefetch_buffer_size=2, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the ImageNet train, validation, and test sets. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + prefetch_buffer_size: int; Buffer size for the device prefetch. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + dataset_configs = dataset_configs or {} + del rng + data_augmentations = dataset_configs.get('data_augmentations', ['default']) + # TODO(dehghani): add mixup data augmentation. + for da in data_augmentations: + if da not in ['default']: + raise ValueError(f'Data augmentation {data_augmentations} is not ' + f'(yet) supported in the ImageNet dataset.') + dtype = getattr(tf, dtype_str) + onehot_labels = dataset_configs.get('onehot_labels', False) + + logging.info('Loading train split of the ImageNet dataset.') + train_ds = imagenet_load_split( + batch_size, + train=True, + onehot_labels=onehot_labels, + dtype=dtype, + shuffle_seed=shuffle_seed, + data_augmentations=data_augmentations) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the ImageNet dataset.') + eval_ds = imagenet_load_split(eval_batch_size, train=False, + onehot_labels=onehot_labels, dtype=dtype) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + train_iter = jax_utils.prefetch_to_device(train_iter, prefetch_buffer_size) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + eval_iter = jax_utils.prefetch_to_device(eval_iter, prefetch_buffer_size) + + input_shape = (-1, IMAGE_SIZE, IMAGE_SIZE, 3) + + meta_data = { + 'num_classes': NUM_CLASSES, + 'input_shape': input_shape, + 'num_train_examples': TRAIN_IMAGES, + 'num_eval_examples': EVAL_IMAGES, + 'input_dtype': getattr(jnp, dtype_str), + 'target_is_onehot': onehot_labels, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/mnist_dataset.py b/scenic/dataset_lib/mnist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c7362dd76796dab122a4d4fc0561ba4840b1e35f --- /dev/null +++ b/scenic/dataset_lib/mnist_dataset.py @@ -0,0 +1,119 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the MNIST dataset.""" + +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + + +def preprocess_example(example, dtype=tf.float32): + """Preprocesses the given image. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + + Returns: + A preprocessed image `Tensor`. + """ + image = dataset_utils.normalize(example['image'], dtype) + return {'inputs': image, 'label': example['label']} + + +@datasets.add_dataset('mnist') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the MNIST train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + del dataset_configs + dtype = getattr(tf, dtype_str) + preprocess_ex = functools.partial(preprocess_example, dtype=dtype) + + logging.info('Loading train split of the MNIST dataset.') + train_ds, train_ds_info = dataset_utils.load_split_from_tfds( + 'mnist', + batch_size, + split='train', + preprocess_example=preprocess_ex, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the MNIST dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'mnist', eval_batch_size, split='test', preprocess_example=preprocess_ex) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, 28, 28, 1) + meta_data = { + 'num_classes': train_ds_info.features['label'].num_classes, + 'input_shape': input_shape, + 'num_train_examples': dataset_utils.get_num_examples('mnist', 'train'), + 'num_eval_examples': dataset_utils.get_num_examples('mnist', 'test'), + 'input_dtype': getattr(jnp, dtype_str), + 'target_is_onehot': False, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/oxford_pets_dataset.py b/scenic/dataset_lib/oxford_pets_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e688ee2e57a6104c7adb2eec75ded481493a2386 --- /dev/null +++ b/scenic/dataset_lib/oxford_pets_dataset.py @@ -0,0 +1,144 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the Oxford-IIIT pet dataset.""" + +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + +IMAGE_SIZE = [224, 224] + + +def preprocess_example(example, dtype=tf.float32): + """Preprocesses the given image. + + Args: + example: dict; Example coming from TFDS. + dtype: Tensorflow data type; Data type of the image. + + Returns: + An example dict as required by the model. + """ + example_out = {} + # For simplicity, just resize all images to the same shape: + example_out['inputs'] = tf.image.resize( + dataset_utils.normalize(example['image'], dtype), IMAGE_SIZE, 'bilinear') + example_out['inputs'] = tf.cast(example_out['inputs'], dtype) + + example_out['label'] = tf.image.resize( + example['segmentation_mask'], IMAGE_SIZE, 'nearest') + example_out['label'] = tf.squeeze(example_out['label'], axis=2) + example_out['label'] = tf.cast(example_out['label'], dtype) + + # The dataset has three classes: object/pet (label 1), background (label 2) + # and object outline (label 3). Convert to zero-indexed labels: + example_out['label'] -= 1 + + return example_out + + +@datasets.add_dataset('oxford_pets') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the Oxford Pet train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + del dataset_configs + dtype = getattr(tf, dtype_str) + preprocess_ex = functools.partial(preprocess_example, dtype=dtype) + + logging.info('Loading train split of the Oxford Pet dataset.') + train_ds, _ = dataset_utils.load_split_from_tfds( + 'oxford_iiit_pet', + batch_size, + split='train', + preprocess_example=preprocess_ex, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the Oxford Pet dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'oxford_iiit_pet', eval_batch_size, split='test', + preprocess_example=preprocess_ex) + + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size, + pixel_level=True) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size, + pixel_level=True) + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, IMAGE_SIZE[0], IMAGE_SIZE[1], 3) + meta_data = { + 'num_classes': + 3, + 'input_shape': + input_shape, + 'num_train_examples': + dataset_utils.get_num_examples('oxford_iiit_pet', 'train'), + 'num_eval_examples': + dataset_utils.get_num_examples('oxford_iiit_pet', 'test'), + 'input_dtype': + getattr(jnp, dtype_str), + 'target_is_onehot': + False, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/svhn_dataset.py b/scenic/dataset_lib/svhn_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea2e192d8c63fc2792ca260e8ef53d336127903 --- /dev/null +++ b/scenic/dataset_lib/svhn_dataset.py @@ -0,0 +1,172 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the SVHN dataset. + +The Street View House Numbers (SVHN) Dataset is an image digit recognition + dataset of over 600,000 color digit images coming from real world data. + Split size: + - Training set: 73,257 images + - Testing set: 26,032 images + - Extra training set: 531,131 images + Following the common setup on SVHN, we only use the official training and + testing data. Images are cropped to 32x32. + + URL: http://ufldl.stanford.edu/housenumbers/ +""" + +import functools +from typing import Optional + +from absl import logging +import jax.numpy as jnp +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +import tensorflow as tf + + +def preprocess_example(example, dtype=tf.float32): + """Preprocesses the given example. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + + Returns: + A preprocessed image `Tensor`. + """ + image = tf.cast(example['image'], dtype=dtype) + if dtype not in [tf.int32, tf.int64, tf.uint32, tf.uint64]: + image /= tf.constant(255.0, shape=[1, 1, 1], dtype=dtype) + return {'inputs': image, 'label': example['label']} + + +def augment_example(example, dtype=tf.float32, data_augmentations=None): + """Apply data augmentation on the given training example. + + Args: + example: dict; Example that has an 'image' and a 'label'. + dtype: Tensorflow data type; Data type of the image. + data_augmentations: list(str); Types of data augmentation applied on + training data. + + Returns: + An augmented training example. + """ + image = tf.cast(example['inputs'], dtype=dtype) + if data_augmentations is not None: + if 'random_crop_flip' in data_augmentations: + image = dataset_utils.augment_random_crop_flip( + image, crop_padding=4, flip=True) + image = tf.cast(image, dtype=dtype) + return {'inputs': image, 'label': example['label']} + + +@datasets.add_dataset('svhn') +def get_dataset(*, + batch_size, + eval_batch_size, + num_shards, + dtype_str='float32', + shuffle_seed=0, + rng=None, + dataset_configs=None, + dataset_service_address: Optional[str] = None): + """Returns generators for the SVHN train, validation, and test set. + + Args: + batch_size: int; Determines the train batch size. + eval_batch_size: int; Determines the evaluation batch size. + num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...]. + dtype_str: Data type of the image (e.g. 'float32'). + shuffle_seed: int; Seed for shuffling the training data. + rng: JAX rng key, which can be used for augmentation, shuffling, etc. + dataset_configs: dict; Dataset specific configurations. + dataset_service_address: If set, will distribute the training dataset using + the given tf.data service at the given address. + + Returns: + A dataset_utils.Dataset() which includes a train_iter, a valid_iter, + a test_iter, and a dict of meta_data. + """ + del rng + dataset_configs = dataset_configs or {} + data_augmentations = dataset_configs.get('data_augmentations', []) + for da in data_augmentations: + if da not in ['random_crop_flip']: + raise ValueError(f'Data augmentation type {da} is not yet supported ' + f'in the SVHN dataset.') + + dtype = getattr(tf, dtype_str) + preprocess_ex = functools.partial(preprocess_example, dtype=dtype) + + logging.info('Loading train split of the SVHN dataset.') + augment_ex = functools.partial( + augment_example, dtype=dtype, data_augmentations=data_augmentations) + train_ds, train_ds_info = dataset_utils.load_split_from_tfds( + 'svhn_cropped:3.*.*', + batch_size, + split='train', + preprocess_example=preprocess_ex, + augment_train_example=augment_ex, + shuffle_seed=shuffle_seed) + + if dataset_service_address: + if shuffle_seed is not None: + raise ValueError('Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you ' + 'want to run with dataset service.') + logging.info('Using the tf.data service at %s', dataset_service_address) + train_ds = dataset_utils.distribute(train_ds, dataset_service_address) + + logging.info('Loading test split of the SVHN dataset.') + eval_ds, _ = dataset_utils.load_split_from_tfds( + 'svhn_cropped:3.*.*', + eval_batch_size, + split='test', + preprocess_example=preprocess_ex) + + shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards) + maybe_pad_batches_train = functools.partial( + dataset_utils.maybe_pad_batch, train=True, batch_size=batch_size) + maybe_pad_batches_eval = functools.partial( + dataset_utils.maybe_pad_batch, train=False, batch_size=eval_batch_size) + + train_iter = iter(train_ds) + train_iter = map(dataset_utils.tf_to_numpy, train_iter) + train_iter = map(maybe_pad_batches_train, train_iter) + train_iter = map(shard_batches, train_iter) + + eval_iter = iter(eval_ds) + eval_iter = map(dataset_utils.tf_to_numpy, eval_iter) + eval_iter = map(maybe_pad_batches_eval, eval_iter) + eval_iter = map(shard_batches, eval_iter) + + input_shape = (-1, 32, 32, 3) + meta_data = { + 'num_classes': + train_ds_info.features['label'].num_classes, + 'input_shape': + input_shape, + 'num_train_examples': + dataset_utils.get_num_examples('svhn_cropped:3.*.*', 'train'), + 'num_eval_examples': + dataset_utils.get_num_examples('svhn_cropped:3.*.*', 'test'), + 'input_dtype': + getattr(jnp, dtype_str), + 'target_is_onehot': + False, + } + return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data) diff --git a/scenic/dataset_lib/tests/__init__.py b/scenic/dataset_lib/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/tests/test_dataset_utils.py b/scenic/dataset_lib/tests/test_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/dataset_lib/video_ops.py b/scenic/dataset_lib/video_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..80683adcdb2a2bd0e227c9f1b95a61d8a50dce85 --- /dev/null +++ b/scenic/dataset_lib/video_ops.py @@ -0,0 +1,842 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preprocessing functions for video data loading. + +Includes SimCLR-style data augmentation functions adapted to be temporally +consistent throughout the video. + +Code is based on: +SimCLR style data augmentation is based on: +https://github.com/google-research/simclr/blob/master/tf2/data_util.py +""" + +import functools +import math +from typing import Optional + + +from absl import logging +from dmvr import builders +from dmvr import processors as dmvr_processors +import simclr.tf2.data_util as simclr_data +import tensorflow as tf +from official.vision.image_classification import augment + + +def _get_shape(x): + """Gets tensor shape as a list, allowing mixing static and dynamic shapes.""" + dynamic_shape = tf.shape(x) + if x.shape.ndims is None: + return dynamic_shape + static_shape = x.shape.as_list() + shapes = [ + static_shape[i] if static_shape[i] is not None else dynamic_shape[i] + for i in range(x.shape.ndims) + ] + return shapes + + +def _fill_rectangle_video(image, + center_width, + center_height, + half_width, + half_height, + replace=None): + """Fills blank area for video.""" + image_time = tf.shape(image)[0] + image_height = tf.shape(image)[1] + image_width = tf.shape(image)[2] + + lower_pad = tf.maximum(0, center_height - half_height) + upper_pad = tf.maximum(0, image_height - center_height - half_height) + left_pad = tf.maximum(0, center_width - half_width) + right_pad = tf.maximum(0, image_width - center_width - half_width) + + cutout_shape = [ + image_time, image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad) + ] + padding_dims = [[0, 0], [lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, + constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.tile(mask, [1, 1, 1, 3]) + + if replace is None: + fill = tf.random.normal(tf.shape(image), dtype=image.dtype) + elif isinstance(replace, tf.Tensor): + fill = replace + else: + fill = tf.ones_like(image, dtype=image.dtype) * replace + image = tf.where(tf.equal(mask, 0), fill, image) + + return image + + +class RandomErasing: + """Applies RandomErasing to a video. + + + Reference: https://arxiv.org/abs/1708.04896 + """ + + def __init__(self, + probability: float = 0.25, + min_area: float = 0.02, + max_area: float = 1 / 3, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None, + min_count=1, + max_count=1, + trials=10): + """Applies RandomErasing to a video. + + Args: + probability: Probability of augmenting the image. Defaults to `0.25`. + min_area: Minimum area of the random erasing rectangle. Defaults to + `0.02`. + max_area: Maximum area of the random erasing rectangle. Defaults to `1/3`. + min_aspect: Minimum aspect rate of the random erasing rectangle. Defaults + to `0.3`. + max_aspect: Maximum aspect rate of the random erasing rectangle. Defaults + to `None`. + min_count: Minimum number of erased rectangles. Defaults to `1`. + max_count: Maximum number of erased rectangles. Defaults to `1`. + trials: Maximum number of trials to randomly sample a rectangle that + fulfills constraint. Defaults to `10`. + """ + self._probability = probability + self._min_area = float(min_area) + self._max_area = float(max_area) + self._min_log_aspect = math.log(min_aspect) + self._max_log_aspect = math.log(max_aspect or 1 / min_aspect) + self._min_count = min_count + self._max_count = max_count + self._trials = trials + + def distort(self, video: tf.Tensor) -> tf.Tensor: + """Applies RandomErasing to video. + + Args: + video (tf.Tensor): Of shape [temporal, height, width, 3] representing a + video. + + Returns: + tf.Tensor: The augmented version of video. + """ + uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0) + mirror_cond = tf.less(uniform_random, self._probability) + video = tf.cond(mirror_cond, lambda: self._erase(video), lambda: video) + return video + + @tf.function + def _erase(self, video: tf.Tensor) -> tf.Tensor: + """Erase an area.""" + if self._min_count == self._max_count: + count = self._min_count + else: + count = tf.random.uniform( + shape=[], + minval=int(self._min_count), + maxval=int(self._max_count - self._min_count + 1), + dtype=tf.int32) + + image_height = tf.shape(video)[1] + image_width = tf.shape(video)[2] + area = tf.cast(image_width * image_height, tf.float32) + + for _ in range(count): + # Work around since break is not supported in tf.function + is_trial_successfull = False + for _ in range(self._trials): + if not is_trial_successfull: + erase_area = tf.random.uniform( + shape=[], + minval=area * self._min_area, + maxval=area * self._max_area) + aspect_ratio = tf.math.exp( + tf.random.uniform( + shape=[], + minval=self._min_log_aspect, + maxval=self._max_log_aspect)) + + half_height = tf.cast( + tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2), + dtype=tf.int32) + half_width = tf.cast( + tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2), + dtype=tf.int32) + + if 2 * half_height < image_height and 2 * half_width < image_width: + center_height = tf.random.uniform( + shape=[], + minval=0, + maxval=int(image_height - 2 * half_height), + dtype=tf.int32) + center_width = tf.random.uniform( + shape=[], + minval=0, + maxval=int(image_width - 2 * half_width), + dtype=tf.int32) + + video = _fill_rectangle_video( + video, + center_width, + center_height, + half_width, + half_height, + replace=None) + + is_trial_successfull = True + return video + + +def random_erasing(frames: tf.Tensor, + probability: float = 0.25, min_area: float = 0.02, + max_area: float = 1 / 3, min_aspect: float = 0.3, + max_aspect: Optional[float] = None, min_count=1, + max_count=1, trials=10): + + """Applies RandomErasing to a video. + + Args: + frames: A Tensor of dimension [timesteps, input_h, input_w, channels]. + probability: Probability of augmenting the image. Defaults to `0.25`. + min_area: Minimum area of the random erasing rectangle. Defaults to + `0.02`. + max_area: Maximum area of the random erasing rectangle. Defaults to `1/3`. + min_aspect: Minimum aspect rate of the random erasing rectangle. Defaults + to `0.3`. + max_aspect: Maximum aspect rate of the random erasing rectangle. Defaults + to `None`. + min_count: Minimum number of erased rectangles. Defaults to `1`. + max_count: Maximum number of erased rectangles. Defaults to `1`. + trials: Maximum number of trials to randomly sample a rectangle that + fulfills constraint. Defaults to `10`. + Returns: + tf.Tensor: The augmented version of video. + """ + random_eraser = RandomErasing(probability, min_area, max_area, min_aspect, + max_aspect, min_count, max_count, trials) + return random_eraser.distort(frames) + + +def crop_resize( + frames: tf.Tensor, + output_h: int, + output_w: int, + num_frames: int, + num_channels: int, + area_range=(0.3, 1), + unused_state=None, + aspect_ratio=(0.5, 2.0), + resize_method: str = tf.image.ResizeMethod.BICUBIC, + resize_antialias: bool = False, +) -> tf.Tensor: + """First crop clip with jittering and then resizes to (output_h, output_w). + + Args: + frames: A Tensor of dimension [timesteps, input_h, input_w, channels]. + output_h: Size of the height of output. + output_w: Size of the width of output. + num_frames: Number of input frames per clip. + num_channels: Number of channels of the clip. + area_range: Random crop will preserve this proportion of the area of the + original frame. + unused_state: Argument included to be compatible with DeepMind Video Reader + preprocessing pipeline functions which pass in a state variable. + aspect_ratio: Aspect ratio range of area based random resizing. + resize_method: Method for resizing the frames. + resize_antialias: If True, apply anti-aliasing when resizing. + + Returns: + A Tensor of shape [timesteps, output_h, output_w, channels] of type + frames.dtype. + """ + + shape = tf.shape(frames) + seq_len, channels = int(shape[0]), int(shape[3]) + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + factor = output_w / output_h + aspect_ratio = (aspect_ratio[0] * factor, aspect_ratio[1] * factor) + + sample_distorted_bbox = tf.image.sample_distorted_bounding_box( + shape[1:], + bounding_boxes=bbox, + min_object_covered=0.1, + aspect_ratio_range=aspect_ratio, + area_range=area_range, + max_attempts=100, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bbox + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + size = tf.convert_to_tensor((seq_len, target_height, target_width, channels)) + offset = tf.convert_to_tensor((0, offset_y, offset_x, 0)) + + frames = tf.slice(frames, offset, size) + frames = tf.cast( + tf.image.resize( + frames, + (output_h, output_w), + method=resize_method, + antialias=resize_antialias, + ), + frames.dtype, + ) + frames.set_shape((num_frames, output_h, output_w, num_channels)) + return frames + + +def simclr_aug_fn(frames, num_frames): + """Applies the Simclr Augment policy to one video clip. + + Args: + frames: `Tensor` of shape [timesteps, height, width, 3]. + num_frames: number of frames. + + Returns: + A Tensor of shape [timesteps, output_h, output_w, channels] being random + augmented with the same operation. + """ + + def random_color_jitter(image, p=1.0): + + def _transform(image): + color_jitter_t = functools.partial( + simclr_data.color_jitter, strength=0.75) + image = simclr_data.random_apply(color_jitter_t, p=0.8, x=image) + return simclr_data.random_apply(simclr_data.to_grayscale, p=0.2, x=image) + + return simclr_data.random_apply(_transform, p=p, x=image) + + frame_list = tf.unstack(frames, num_frames, 0) + # Temporally random version + # simclr_aug_frame_list = [] + # for image in frame_list: + # image = random_color_jitter(image) + # simclr_aug_frame_list.append(image) + # return tf.stack(simclr_aug_frame_list, axis=0) + + # Temporally consistent version + big_image = tf.concat(frame_list, axis=0) # [t*h, w, c] + big_image = random_color_jitter(big_image) + simclr_aug_frame_list = tf.split(big_image, num_or_size_splits=num_frames) + return tf.stack(simclr_aug_frame_list, axis=0) # [t, h, w, c] + + +def batch_random_blur(images, height, width, blur_probability=0.5): + """Random blur to all frames. + + All frames have a blur applied to them, or all do not. + + Args: + images: `Tensor` of shape [timesteps, height, width, 3].. + height: the height of image. + width: the width of image. + blur_probability: the probaility to apply the blur operator. + + Returns: + Blurred images. + """ + + def generate_selector(p, bsz): + shape = [bsz, 1, 1, 1] + selector = tf.cast( + tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p), + tf.float32) + return selector + + images_new = simclr_data.random_blur(images, height, width, p=1.) + # All frames have augmentation applied, or not. + selector = generate_selector(blur_probability, 1) + images = images_new * selector + images * (1 - selector) + images = tf.clip_by_value(images, 0., 1.) + + return images + + +def random_solarization(image, p=0.2): + + def _transform(image): + image = image * tf.cast(tf.less(image, 0.5), tf.float32) + ( + 1.0 - image) * tf.cast(tf.greater_equal(image, 0.5), tf.float32) + return image + + return simclr_data.random_apply(_transform, p=p, x=image) + + +def random_time_reverse(image, p=0.5): + + def _transform(image): + return image[::-1, :, :, :] + + return simclr_data.random_apply(_transform, p=p, x=image) + + +def simclr_style_augmentation(frames, height, width, zero_centre): + """Applies SimCLR-style random augmentations to frames. + + Args: + frames: `Tensor` of shape [timesteps, height, width, 3]. + height: Image height. + width: Image width. + zero_centre: Bool. If true, frames are between [-1. 1]. Otherwise, they are + in the range [0, 1] + + Returns: + A Tensor of shape [timesteps, height, width, channels] being random + augmented with the same operation. + """ + num_frames = frames.shape[0] + frames = simclr_aug_fn(frames, num_frames) + blur_frames = batch_random_blur(frames, height, width) + solarize_frames = random_solarization(blur_frames) + reversed_frames = random_time_reverse(solarize_frames) + reversed_frames = tf.clip_by_value(reversed_frames, 0., 1.) + + if zero_centre: + return reversed_frames * 2.0 - 1.0 + else: + return reversed_frames + + +def deterministic_crop(images, size, spatial_idx): + """Takes a deterministic crop of input images. + + Args: + images: `Tensor` of shape shape [t, h, w, c] + size: Integer ; size of height and width to crop the images. + spatial_idx: 0, 1, or 2 for left, center, and right crop if width is larger + than height. Or 0, 1, or 2 for top, center, and bottom crop if height is + larger than width. + + Returns: + cropped: `Tensor` of shape [t, crop_size, crop_size, c] + """ + assert spatial_idx in [0, 1, 2] + height, width = tf.shape(images)[1], tf.shape(images)[2] + + y_offset = tf.cast(tf.math.ceil((height - size) / 2), tf.int32) + x_offset = tf.cast(tf.math.ceil((width - size) / 2), tf.int32) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + + cropped = tf.slice(images, [0, y_offset, x_offset, 0], [-1, size, size, -1]) + + return cropped + + +def three_spatial_crops(images, crop_size): + """Returns three spatial crops of the same frame, as done by SlowFast. + + This enables testing using the same protocol as prior works. ie + (https://arxiv.org/abs/1812.03982, https://arxiv.org/abs/1904.02811, + https://arxiv.org/abs/2004.04730) + If width > height, takes left, centre and right crop. + If height > width, takes top, middle and bottom crop. + + Args: + images: `Tensor` of shape [t, h, w, c] + crop_size: The size to crop from the images + + Returns: + `Tensor` of shape [3 * t, h, w, c] + """ + + result = [] + for spatial_index in range(3): + images_cropped = deterministic_crop(images, crop_size, spatial_index) + result.append(images_cropped) + + return tf.concat(result, axis=0) + + +def additional_augmentations( + ds_factory, + augmentation_params, + crop_size, + num_frames, + zero_centering, + rgb_feature_name=None, + resize_method: str = tf.image.ResizeMethod.BICUBIC, + resize_antialias: bool = False, +): + """Apply additional data augmentations in the DMVR pre-processsing graph.""" + + if not rgb_feature_name: + rgb_feature_name = builders.IMAGE_FEATURE_NAME + + do_simclr_crop_resize = augmentation_params.get('do_simclr_crop_resize', + False) + do_simclr_style_augmentations = augmentation_params.get( + 'do_simclr_style_augmentations', False) + do_rand_augment = augmentation_params.get('do_rand_augment', False) + do_color_augment = augmentation_params.get('do_color_augment', False) + do_jitter_scale = augmentation_params.get('do_jitter_scale', False) + do_random_erasing = augmentation_params.get('do_random_erasing', False) + + if do_simclr_crop_resize and do_jitter_scale: + logging.warning('Only doing simclr_crop_resize.' + 'Not compatible with jitter_scale') + + if do_simclr_crop_resize: + area_range = (augmentation_params.get('simclr_area_lower_bound', 0.5), 1) + aspect_ratio = augmentation_params.get('aspect_ratio_crop', (0.5, 2.0)) + + # Remove resize_smallest and Replace random_crop with crop_resize + ds_factory.preprocessor_builder.remove_fn( + f'{rgb_feature_name}_resize_smallest') + # To replace random_crop with the crop_resize we need to find out which + # function comes next, as not all datasets have the same list of + # preprocessing functions (e.g. SSv2 doesn't have a random_flip) + randcrop_fn_name = f'{rgb_feature_name}_random_crop' + fns_list = ds_factory.preprocessor_builder.get_summary() + idx = [i for i, fd in enumerate(fns_list) if fd.fn_name == randcrop_fn_name] + if not idx: + raise ValueError(f'No {randcrop_fn_name} in Preprocessing Builder.') + next_fn_name = fns_list[idx[0] + 1].fn_name + ds_factory.preprocessor_builder.remove_fn(randcrop_fn_name) + ds_factory.preprocessor_builder.add_fn( + functools.partial( + crop_resize, + num_frames=num_frames, + output_h=crop_size, + output_w=crop_size, + num_channels=3, + area_range=area_range, + aspect_ratio=aspect_ratio, + resize_method=resize_method, + resize_antialias=resize_antialias, + ), + feature_name=rgb_feature_name, + fn_name=f'{rgb_feature_name}_crop_resize', + add_before_fn_name=next_fn_name, + ) + + elif do_jitter_scale: + ds_factory.preprocessor_builder.add_fn( + functools.partial( + dmvr_processors.scale_jitter_augm, + min_scale_factor=augmentation_params.scale_min_factor, + max_scale_factor=augmentation_params.scale_max_factor, + prob=augmentation_params.prob_scale_jitter), + feature_name=rgb_feature_name, + fn_name=f'{rgb_feature_name}_jitter_scale', + add_before_fn_name=f'{rgb_feature_name}_random_crop') + + if do_simclr_style_augmentations and do_color_augment: + logging.warning('Only doing simclr_style_augmentations as it includes' + 'color augmentations') + + if sum([do_rand_augment, do_simclr_style_augmentations, do_color_augment + ]) > 1: + logging.warning('Priority for different augmentation functions is:' + '1) rand_augment. 2) simclr_style_augment.' + '3) colour_augment. Only one is performed.') + + if do_rand_augment: + logging.info('Adding rand_augment') + ds_factory.preprocessor_builder.add_fn( + functools.partial( + distort_image_with_randaugment, + num_layers=augmentation_params.rand_augment_num_layers, + magnitude=augmentation_params.rand_augment_magnitude, + ), + feature_name=rgb_feature_name, + fn_name=f'{rgb_feature_name}_rand_augment', + add_before_fn_name=f'{rgb_feature_name}_normalize') + elif do_simclr_style_augmentations: + # Add additional augmentations at the end + logging.info('Adding simclr_style augmentation') + ds_factory.preprocessor_builder.add_fn( + functools.partial( + simclr_style_augmentation, + height=crop_size, + width=crop_size, + zero_centre=zero_centering), rgb_feature_name) + elif do_color_augment: + logging.info('Adding color_augment') + ds_factory.preprocessor_builder.add_fn( + functools.partial( + dmvr_processors.color_default_augm, + zero_centering_image=zero_centering, + prob_color_augment=augmentation_params.prob_color_augment, + prob_color_drop=augmentation_params.prob_color_drop), + rgb_feature_name) + + if do_random_erasing: + logging.info('Adding random erasing') + random_erasing_prob = augmentation_params.get('random_erasing_prob', 0.25) + ds_factory.preprocessor_builder.add_fn( + functools.partial(random_erasing, probability=random_erasing_prob), + rgb_feature_name) + + return ds_factory + + +def random_sample_sequence_with_centre( + sequence: tf.Tensor, + num_steps: int, + stride: int = 1, + seed: Optional[int] = None, + state: Optional[builders.ProcessorState] = None) -> tf.Tensor: + """Samples a single segment of size `num_steps` from a given sequence. + + The segment is randomly chosen such that it contains the middle element + of the sequence. + + Args: + sequence: Any tensor where the first dimension is timesteps. + num_steps: Number of steps (e.g. frames) to take. + stride: Distance to sample between timesteps. + seed: A deterministic seed to use when sampling. + state: A mutable dictionary where keys are strings. The dictionary might + contain 'sample_offset_proportion' as key with metadata useful for + sampling. It will be modified with added metadata if needed. This can be + used to keep consistency between sampling of different sequences. + + Returns: + A single tensor with first dimension `num_steps` with the sampled segment. + """ + sequence_length = tf.shape(input=sequence)[0] + offset_lower_bound = tf.maximum(sequence_length / 2 - num_steps * stride, 0) + offset_upper_bound = sequence_length / 2 + + offset = tf.random.uniform( + (), + minval=tf.cast(offset_lower_bound, dtype=tf.int32), + maxval=tf.cast(offset_upper_bound, dtype=tf.int32), + dtype=tf.int32, + seed=seed) # Samples from [lower_bound, upper_bound) + + indices = dmvr_processors.sample_or_pad_sequence_indices( + sequence=sequence, + num_steps=num_steps, + repeat_sequence=True, # Will repeat the sequence if we request more. + stride=stride, + offset=offset) + indices.set_shape((num_steps,)) + output = tf.gather(sequence, indices) + + if state is not None: + # Update state. + sample_offset_proportion = ( + tf.cast(offset, tf.float32) / tf.cast(sequence_length, tf.float32)) + state['sample_offset_proportion'] = sample_offset_proportion + + return output + + +def cutout(big_image, pad_size, num_frames, replace=0) -> tf.Tensor: + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + + Args: + big_image: An image Tensor of type uint8. Shape is [t * h, w, c] + pad_size: Specifies how big the zero mask that will be generated is that is + applied to the image. The mask will be of size (2*pad_size x 2*pad_size). + num_frames: Specifies the t dimension in the input shape. + replace: What pixel value to fill in the image in the area that has the + cutout mask applied to it. + + Returns: + An image Tensor that is of type uint8. + """ + big_image_shape = _get_shape(big_image) + image = tf.reshape(big_image, [ + num_frames, big_image_shape[0] // num_frames, big_image_shape[1], + big_image_shape[2] + ]) + image_height = tf.shape(image)[1] + image_width = tf.shape(image)[2] + + # Sample the center location in the image where the zero mask will be applied. + cutout_center_height = tf.random.uniform( + shape=[], minval=0, maxval=image_height, dtype=tf.int32) + + cutout_center_width = tf.random.uniform( + shape=[], minval=0, maxval=image_width, dtype=tf.int32) + + lower_pad = tf.maximum(0, cutout_center_height - pad_size) + upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = tf.maximum(0, cutout_center_width - pad_size) + right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [ + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad) + ] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = tf.pad( + tf.zeros(cutout_shape, dtype=image.dtype), + padding_dims, + constant_values=1) + mask = tf.expand_dims(mask, -1) + mask = tf.expand_dims(mask, 0) + mask = tf.tile(mask, [num_frames, 1, 1, 3]) + image = tf.where( + tf.equal(mask, 0), + tf.ones_like(image, dtype=image.dtype) * replace, image) + + big_image = tf.reshape(image, [num_frames * image_height, image_width, 3]) + return big_image + + +NAME_TO_FUNC = { + 'AutoContrast': augment.autocontrast, + 'Equalize': augment.equalize, + 'Invert': augment.invert, + # 'Rotate': wrapped_rotate, + 'Posterize': augment.posterize, + 'Solarize': augment.solarize, + 'SolarizeAdd': augment.solarize_add, + 'Color': augment.color, + 'Contrast': augment.contrast, + 'Brightness': augment.brightness, + 'Sharpness': augment.sharpness, + # 'ShearX': shear_x, + # 'ShearY': shear_y, + # 'TranslateX': translate_x, + # 'TranslateY': translate_y, + 'Cutout': cutout, +} + +# Functions that have a 'replace' parameter +REPLACE_FUNCS = frozenset({ + 'Rotate', + 'TranslateX', + 'ShearX', + 'ShearY', + 'TranslateY', + 'Cutout', +}) + + +def _parse_policy_info(name, prob, level, replace_value, cutout_const, + translate_const): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = augment.level_to_arg(cutout_const, translate_const)[name](level) + + if name in REPLACE_FUNCS: + # Add in replace arg if it is required for the function that is called. + args = tuple(list(args) + [replace_value]) + + return func, prob, args + + +def distort_image_with_randaugment(frames, + num_layers, + magnitude, + cutout_const=40, + translate_const=100): + """Applies the RandAugment policy to `image`. + + The original rand_augment implementation is for images. To be temporally + consistent in video, we + -- Reshape the video clip [t, h, w, c] to [t * h, w, c] + -- Only apply functions that do not depend on spatial extent (ie rotate, + shear, translate) + -- We do, however, use a modified cutout. + + Args: + frames: `Tensor` of shape [t, h, w, 3] representing an image. + num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Usually best values are in the range [5, + 10]. + cutout_const: multiplier for applying cutout. + translate_const: multiplier for applying translation. + + Returns: + The augmented version of `frames`. + """ + available_ops = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Posterize', + 'Solarize', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'Cutout', + 'SolarizeAdd', + # 'Rotate', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', + ] + + input_shape = _get_shape(frames) + num_frames = input_shape[0] + image = tf.reshape(frames, [-1, frames.shape[2], frames.shape[3]]) + input_image_type = image.dtype + + if input_image_type != tf.uint8: + image = tf.clip_by_value(image, 0.0, 255.0) + image = tf.cast(image, dtype=tf.uint8) + + replace_value = [128] * 3 + min_prob, max_prob = 0.2, 0.8 + + for _ in range(num_layers): + op_to_select = tf.random.uniform([], + maxval=len(available_ops) + 1, + dtype=tf.int32) + + branch_fns = [] + for (i, op_name) in enumerate(available_ops): + prob = tf.random.uniform([], + minval=min_prob, + maxval=max_prob, + dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, prob, magnitude, + replace_value, cutout_const, + translate_const) + + if op_name == 'Cutout': + args = (args[0], num_frames) + + branch_fns.append(( + i, + # pylint:disable=g-long-lambda + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args))) + # pylint:enable=g-long-lambda + + image = tf.switch_case( + branch_index=op_to_select, + branch_fns=branch_fns, + default=lambda: tf.identity(image)) + + image = tf.cast(image, dtype=input_image_type) + return tf.reshape(image, input_shape) diff --git a/scenic/main.py b/scenic/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3f459e6c29331035755b7092fdbf195a84b96370 --- /dev/null +++ b/scenic/main.py @@ -0,0 +1,66 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main file for Scenic.""" + +from absl import flags +from absl import logging +from clu import metric_writers +from flax.training import checkpoints +import jax +import jax.numpy as jnp +import ml_collections +from scenic import app +from scenic.model_lib import models +from scenic.train_lib import train_utils +from scenic.train_lib import trainers + + +FLAGS = flags.FLAGS + + +def main(rng: jnp.ndarray, config: ml_collections.ConfigDict, workdir: str, + writer: metric_writers.MetricWriter) -> None: + """Main function for Scenic.""" + + model_cls = models.get_model_cls(config.model_name) + data_rng, rng = jax.random.split(rng) + + if config.checkpoint: + # When restoring from a checkpoint, change the dataset seed to ensure that + # the example order is new. With deterministic data, this ensures enough + # randomization and in the future with deterministic data + random access, + # we can feed the global step to the dataset loader to always continue + # reading the rest of the data if we resume a job that was interrupted. + checkpoint_path = checkpoints.latest_checkpoint(workdir) + logging.info('CHECKPOINT PATH: %s', checkpoint_path) + if checkpoint_path is not None: + global_step = train_utils.checkpoint_path_step(checkpoint_path) or 0 + logging.info('Folding global_step %s into dataset seed.', global_step) + data_rng = jax.random.fold_in(data_rng, global_step) + + dataset = train_utils.get_dataset( + config, data_rng, dataset_service_address=FLAGS.dataset_service_address) + + trainers.get_trainer(config.trainer_name)( + rng=rng, + config=config, + model_cls=model_cls, + dataset=dataset, + workdir=workdir, + writer=writer) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/scenic/model_lib/README.md b/scenic/model_lib/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ee79cd4d80610a3442897671d9f1a33b8d476594 --- /dev/null +++ b/scenic/model_lib/README.md @@ -0,0 +1,109 @@ + +## Scenic `BaseModel` +A solution usually has several parts: data/task pipeline, model architecture, +losses and metrics, training and evaluation, etc. Given that much of research +done in Scenic is trying out different architectures, Scenic introduces the +concept of `model`, to facilitate plug-in/plug-out experiments. A Scenic model +is defined as the network architecture plus the losses that are used to update +the weights of the network as well as metrics that are used to evaluate the +output of the network. This is implemented as `BaseModel`. + +`BaseModel` is an abstract class with three members: `get_metrics_fn`, +`loss_fn`, and a `build_flax_model`. + +`get_metrics_fn` returns a callable function, `metric_fn`, that calculates the +metrics and returns a dictionary. The metric function computes `f(x_i, y_i)` on +a mini-batch, it has API: + +```python +metric_fn(logits, label, weights) +``` + +The trainer will then aggregate and compute the mean across all samples +evaluated. + +`loss_fn` is a function of API: + +```python +loss = loss_fn(logits, batch, +model_params=None) +``` + +And finally a `flax_model` is returned from the `build_flax_model` function. A +typical usage pattern will be: + +```python +model_cls = model_lib.models.get_model_cls('fully_connected_classification') +model = model_cls(config, dataset.meta_data) +flax_model = model.build_flax_model +dummy_input = jnp.zeros(input_shape, model_input_dtype) +model_state, params = flax_model.init( + rng, dummy_input, train=False).pop('params') +``` + +And this is how to call the model: + +```python +variables = {'params': params, **model_state} logits, +new_model_state = flax_model.apply(variables, inputs, ...) +``` + +The abstract classes defining Scenic models, including `BaseModel` that defines +the Scenic model as well as `ClassificationModel`, +`MultiLabelClassificationModel`, `EncoderDecoderModel`, `SegmentationModelthat` +that define losses and metrics for classification, seq2seq, and segmentation +tasks are defined in `model_lib/base_models`. A Scenic project can define a new +base-class based on the task, metrics or overwrite the existing one when it is +needed. + +Also, it is important to say that this design pattern, although recommended, is +not forced and there is no issue deviating from such structure, as some projects +in Scenic already do that. + +## Implementing loss and metrics with data parallelism +In Scenic, all the default training loops are designed to support data +parallelism. To do so, we have to be careful about our loss +and metrics calculations. + +When training on multiple devices on multiple hosts, the gradient calculations +are handled locally on each device, given the examples in the device batch. So, +in the loss function, we simply "average" over the loss of all examples in that +device and return a **scalar** value, indicating the loss in that device (Check +out `weighted_softmax_cross_entropy` loss in the [base_models/model_lib.py](./base_models/model_lib.py) +as an example). Then, in the training loop, we compute the gradient on each +device given the loss on that device. Then we **average** over the gradient from +all devices in all hosts: + +```python +grad = jax.lax.pmean(grad, axis_name='batch') +``` + +Note that the `pmean` operation is synchronised across all hosts. + +For metrics, however, the averaging is not done locally to make sure that we +account for actual number of examples in the partial last batch of +test/validation sets. +So each device returns two items: (1) the sum of the "per-example" value of that +metric and (2) number of actual examples processed by that device (to be used +for normalizing the value of that metric). Then, we **sum** over these two items +over all devices in all hosts (check out `psum_metric_normalizer` function +in [base_models/model_lib.py](./base_models/model_lib.py) and pass a tuple of +two scalars for each metric ``. +Then, the summary writer uses the sum and the normalizer to compute the final +value of the metric. +So if you implement a new metric, you should be careful of returning the sum +and normalizer instead of the average of metric value over the examples in the +device (local) batch to guarantee the correctness of metrics' calculation. + +This might seem a bit complicated, however, this is necessary as this carefully +accounts for the potential partial last batch in the test/validation splits and +guarantees correct computation of metrics. More precisely, if we average +locally and compute the mean of local averages, the batches with less example +would contribute to the final mean as much as full batches. + +Note that there are metrics that do not decompose across different examples, +and cannot be computed as `sum(metric_val)/N`, like Mean Average +Precision. For such metrics, we need a special procedure to bring all the +`` pairs to the host and then compute the metrics we want. +You can look at [DETR implementation](../projects/baselines/detr) to learn more +about how this can be implemented using `lax.all_gather`. diff --git a/scenic/model_lib/__init__.py b/scenic/model_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/__pycache__/__init__.cpython-310.pyc b/scenic/model_lib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/__init__.py b/scenic/model_lib/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/base_model.py b/scenic/model_lib/base_models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/box_utils.py b/scenic/model_lib/base_models/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/classification_model.py b/scenic/model_lib/base_models/classification_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/encoder_decoder_model.py b/scenic/model_lib/base_models/encoder_decoder_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/model_utils.py b/scenic/model_lib/base_models/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/multilabel_classification_model.py b/scenic/model_lib/base_models/multilabel_classification_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/regression_model.py b/scenic/model_lib/base_models/regression_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/base_models/segmentation_model.py b/scenic/model_lib/base_models/segmentation_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/layers/__init__.py b/scenic/model_lib/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/layers/attention_layers.py b/scenic/model_lib/layers/attention_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/layers/masked_layers.py b/scenic/model_lib/layers/masked_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/layers/nn_layers.py b/scenic/model_lib/layers/nn_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/layers/nn_ops.py b/scenic/model_lib/layers/nn_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/__init__.py b/scenic/model_lib/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/common.py b/scenic/model_lib/matchers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/greedy.py b/scenic/model_lib/matchers/greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/hungarian.py b/scenic/model_lib/matchers/hungarian.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/hungarian_cover.py b/scenic/model_lib/matchers/hungarian_cover.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/hungarian_jax.py b/scenic/model_lib/matchers/hungarian_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/lazy.py b/scenic/model_lib/matchers/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/matchers/sinkhorn.py b/scenic/model_lib/matchers/sinkhorn.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/models.py b/scenic/model_lib/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8e31176e75d890924a61bb3c3d53059784dd17 --- /dev/null +++ b/scenic/model_lib/models.py @@ -0,0 +1,84 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry for the available models we can train.""" + +from typing import Type + +from scenic.model_lib.base_models import base_model +from scenic.projects.baselines import axial_resnet +from scenic.projects.baselines import bit_resnet +from scenic.projects.baselines import fully_connected +from scenic.projects.baselines import hybrid_vit +from scenic.projects.baselines import mixer +from scenic.projects.baselines import resnet +from scenic.projects.baselines import simple_cnn +from scenic.projects.baselines import unet +from scenic.projects.baselines import vit + +ALL_MODELS = {} + +CLASSIFICATION_MODELS = { + 'fully_connected_classification': + fully_connected.FullyConnectedClassificationModel, + 'simple_cnn_classification': + simple_cnn.SimpleCNNClassificationModel, + 'axial_resnet_multilabel_classification': + axial_resnet.AxialResNetMultiLabelClassificationModel, + 'resnet_classification': + resnet.ResNetClassificationModel, + 'resnet_multilabel_classification': + resnet.ResNetMultiLabelClassificationModel, + 'bit_resnet_classification': + bit_resnet.BitResNetClassificationModel, + 'bit_resnet_multilabel_classification': + bit_resnet.BitResNetMultiLabelClassificationModel, + 'vit_multilabel_classification': + vit.ViTMultiLabelClassificationModel, + 'hybrid_vit_multilabel_classification': + hybrid_vit.HybridViTMultiLabelClassificationModel, + 'mixer_multilabel_classification': + mixer.MixerMultiLabelClassificationModel, +} + +SEGMENTATION_MODELS = { + 'simple_cnn_segmentation': simple_cnn.SimpleCNNSegmentationModel, + 'unet_segmentation': unet.UNetSegmentationModel, +} + + +ALL_MODELS.update(CLASSIFICATION_MODELS) +ALL_MODELS.update(SEGMENTATION_MODELS) + + +def get_model_cls(model_name: str) -> Type[base_model.BaseModel]: + """Get the corresponding model class based on the model string. + + API: + ``` + model_builder= get_model_cls('fully_connected') + model = model_builder(config, ...) + ``` + + Args: + model_name: str; Name of the model, e.g. 'fully_connected'. + + Returns: + The model architecture (a flax Model) along with its default config. + Raises: + ValueError if model_name is unrecognized. + """ + if model_name not in ALL_MODELS.keys(): + raise ValueError('Unrecognized model: {}'.format(model_name)) + return ALL_MODELS[model_name] diff --git a/scenic/model_lib/tests/__init__.py b/scenic/model_lib/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/model_lib/tests/test_models.py b/scenic/model_lib/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/README.md b/scenic/projects/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5e06c0f2f0a654b7c83d8f43e54d31227831bbf --- /dev/null +++ b/scenic/projects/README.md @@ -0,0 +1,248 @@ +## Contents +* [List of projects hosted in Scenic](#list-of-projects-hosted-in-scenic) +* [Scenic projects](#scenic-projects) + + +## List of projects hosted in Scenic + +* [AdaTape](adatape) + + > AdaTape is an adaptive computation transformer with elastic input sequence. + +* [AdversarialTraining](adversarialtraining) + + > Adversarial training is an implementation of modern forms of adversarial + > training that achieved state-of-the-art robustness results on image + > classifications. This includes [AdvProp](https://arxiv.org/abs/1911.09665) + > and (Pyramid Adversarial Training Improves ViT Performance)[https://arxiv.org/abs/2111.15121]. + +* [AVATAR](avatar) + + > [AVATAR](https://gabeur.github.io/avatar-visspeech) is a + > sequence-to-sequence AudioVisual ASR TrAnsformeR which is + > trained end-to-end from spectrograms and full-frame RGB for the task of + > audiovisual speech recognition (AV-ASR). + +* [Audiovisual Masked Autoencoders](av-mae) + + > Audiovisual Masked Autoencoders performs self-supervised learning on + > multiple modalities (audio and video) to improve representation learning + > for both unimodal and multimodal downstream tasks. Details can be found + > in the [paper](https://arxiv.org/abs/2212.05922). + +* [Boundary Attention](boundary_attention) + + > Boundary Attention is differentiable bottom-up model for detecting + > boundaries in high noise at any resolution. It uses a form of local + > attention to infer boundaries that include contours, corners and + > junctions, all without rasterization. Details and a link to + > the paper can be found on its [website](https://boundaryattention.github.io/). + +* [ViViT](vivit) + + > ViViT is a family of pure-transformer based models for video + > classification that achieved state-of-the-art results. + > Details can be found in the [paper](https://arxiv.org/abs/2103.15691). + +* [Tasseo](tasseo) + + > Tasseo is a project that uses transformer based models for aberration + > detection from chromosome karyotype images. + +* [TokenLearner](token_learner) + + > TokenLearner proposes dynamic tokenization of images and videos for faster + > and more accurate video/image processing tasks. More can be found in + > the [paper](https://arxiv.org/abs/2106.11297). + +* [Token Turing Machines](token_turing) + + > Token Turing Machines are a sequential, autoregressive transformer + > architecture with external memory. More can be found in the + > [paper](https://arxiv.org/abs/2106.11297). + +* [FastViT](fast_vit) + + > FastViT is a project that aims at exploring ideas around making ViT faster + > via using [efficient transformers](https://arxiv.org/abs/2009.06732), in + > particular on higher resolution inputs (more tokens and thus longer + > sequences). + +* [Omninet](omninet) + + > Omninet is a transformer model with + > [omni-directional representations](https://arxiv.org/abs/2103.01075). + +* [CLAY](layout_denoise) + + > CLAY is a Transformer-based pipeline for mobile UI layout denoising. Read + > more about this project in CLAY [paper](https://arxiv.org/abs/2201.04100). + +* [LOCA](loca) + + > LOCA ([paper](https://arxiv.org/abs/2212.02400)) is a self-supervised + > method to train spatially-aware vision transformer features. + +* [MatViT](matvit) + > MatViT is a MatFormer ([paper](https://arxiv.org/abs/2310.07707)) based + > nested ViT architecture designed to offer elasticity in a variety of + > deployment constraints, where each Feed Forward Network (FFN) block of a + > MatViT model is jointly optimized with a few nested smaller FFN blocks. + +* [MBT](mbt) + + > MBT presents a transformer based architecture that uses "fusion + > bottlenecks" for modality fusion at multiple layers. + > Details can be found in the [paper](https://arxiv.org/abs/2201.04100). + +* [MTV](mtv) + + > MTV presents a state-of-the-art transformer based architecture for video + > classification. MTV consists of separate encoders to represent different + > views of the input video with lateral connections and a global encoder to + > fuse information across views. More details are in the + > [paper](https://arxiv.org/abs/2201.04288). + +* [OWL-ViT](owl_vit) + + > OWL-ViT is an open-vocabulary object detector that given an image and a + > free-text query, it finds objects matching that query in the image. It can + > also do one-shot object detection, i.e. detect objects based on a single + > example image. More details are in the + > [paper](https://arxiv.org/abs/2205.06230). + +* [NCR](ncr) + + > NCR is a regularization method which encourages the network to make + > similar predictions for similar vectors in the feature space. + > Details can be found in the [paper](https://arxiv.org/abs/2202.02200), + > where we used this method to learn with noisy labels. + +* [PCT](pointcloud) + + > Point Cloud Transformer (PCT) is a Transformer-based model for + > performing inference (classification/segmentation) for point cloud data. + > Details can be found in the [paper](https://arxiv.org/abs/2012.09688). + +* [PolyViT](polyvit) + + > PolyViT is a simple and effective model for co-training a single + > transformer backbone on multiple modalities and tasks, resulting in a + > parameter-efficient model that performs as well or better than models + > trained on single modalities or tasks. + > Details can be found in the [paper](https://arxiv.org/abs/2111.12993). + +* [T5](t5) + + > Wrappers of T5 models in [t5x](https://github.com/google-research/t5x). + +* [Vid2Seq](vid2seq) + + > Vid2Seq is a single-stage dense video captioning model, pre-trained on + > unlabelled narrated videos. + > Details can be found in the [paper](https://arxiv.org/abs/2302.14115). + +* [ObjectViViT](objectvivit) + + > ObjectViViT uses object detection results from external object detectors + > to help action recognition. + > Details can be found in the [paper](https://openaccess.thecvf.com/content/CVPR2023/html/Zhou_How_Can_Objects_Help_Action_Recognition_CVPR_2023_paper.html). + +* [Verbs in action](verbs_in_action) + + > Verbs in action ([paper](https://arxiv.org/abs/2304.06708)) uses LLMs to + > create hard negative pairs for contrastive learning, in order to improve + > the verb understanding of video-text models based on CLIP. + +* [UniVRD](univrd) + + > UniVRD is a bottom-up visual relationship detector built upon pre-trained + > vision and language models. + > Details can be found in the [paper](https://arxiv.org/abs/2303.08998). + +* [UnLoc](unloc) + + > UnLoc proposes a unified architecture for video localization tasks, + > e.g., Temporal Action Localization, Moment Retrieval, and Action + > Segmentation. More details can be found in the [paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Yan_UnLoc_A_Unified_Framework_for_Video_Localization_Tasks_ICCV_2023_paper.pdf). + +* [REVEAL](knowledge_visual_language) + + > REVEAL is an Retrieval-Augmented Visual Language Model that + > learns to retrieve world knowledge from a diverse set of multimodal + > knowledge sources, through end-to-end pre-training. + > Details can be found in the [paper](https://arxiv.org/abs/2212.05221). + + * [PixelLLM](pixel_llm) + > PixelLLM equips large language models with localization capability. + > Details can be found in the [paper](https://arxiv.org/abs/2312.09237). + +* [GER-ALD](gerald) + + > GER-ALD is a novel generative framework for web-scale visual entity + > recognition. We represent each entity by a compact, discriminative and + > semantic code that a generative model learns to auto-regressively decode. + > Details can be found in the [paper](https://arxiv.org/abs/2403.02041). + +* [Streaming Dense Video Captioning](streaming_dvc) + + > Streaming DVC is a framework for dense captioning of long videos. + > Details can be found in the [paper](https://arxiv.org/abs/2404.01297). + +* [Dense Video Object Captioning](densevoc) + + > Dense VOC is an end-to-end model for joint object detection, tracking, + > and captioning in videos. + > Details can be found in the [paper](https://arxiv.org/abs/2306.11729). + + +## Scenic projects +A typical project consists of models, trainers, configs, a runner, and some +utility functions developed for the project. + +### Models +Models are entities that define the network architecture, loss function, and +metrics. Network architectures are built using Flax `nn.Modules`. Common loss +functions and metrics can be included via a +[Base Model](../model_lib/README.md#base_model), or within the project +itself for more specific use-cases. + +To be accessible by the trainer, a model newly-defined by a project needs to be +registered *within a specific project*. As an exception, the baseline models +are registered directly in `model_lib.models`. + +### Trainers +Trainers implement the training and evaluation loops of the model. There are +already standard trainers that are provided in Scenic for classification, +segmentation, and adaptation (located in the `train_lib` module). +These trainers are directly registered in `train_lib_deprecated/trainers` and +given the careful optimization of these trainers for fast and efficient training +on accelerators (in particular TPUs), they can be forked by projects for further +customization. Projects need to register the new trainers they define within +their project, or they can simply use the standard Scenic trainers when no +modification is needed. + +### Configs +Config files are used to configure experiments. They define (hyper-)parameters +for the selected model, trainer, and dataset (e.g. number of layers, frequency +of logging, etc). + +### Binaries +Binaries bind models, trainers, and datasets together based on the config and +start the training. Usually, this is a `main.py` within the project that also +contains the registry for the project specific models and trainers. Note that +baselines make use of Scenic's default binary `main.py`. + +### Registries +There are three types of objects that can be registered in Scenic: +`model`, `trainer`, and `dataset`. A registry could be any simple data structure +that maps a string name to an object, for instance, a python dictionary. + +Scenic defines a dataset registry that uses ad-hoc importing to lazy-load +the code for the input pipeline of a requested dataset. This registry lives in +`dataset_lib/datasets.py`. There are common trainers and models that are +registered in `train_lib_deprecated/trainers.py` and `model_lib/models.py`. However, +a project can define its own dataset, model, and trainer and make a small +registry for these objects within the project, e.g. in the project's `main.py` +so that the right model, trainer, and dataset are selectable using the +configs specified in the config file. diff --git a/scenic/projects/__init__.py b/scenic/projects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/__pycache__/__init__.cpython-310.pyc b/scenic/projects/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adatape/README.md b/scenic/projects/adatape/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adatape/__init__.py b/scenic/projects/adatape/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adatape/layers.py b/scenic/projects/adatape/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adatape/main.py b/scenic/projects/adatape/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adversarialtraining/README.md b/scenic/projects/adversarialtraining/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adversarialtraining/classification_adversarialtraining_trainer.py b/scenic/projects/adversarialtraining/classification_adversarialtraining_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adversarialtraining/main.py b/scenic/projects/adversarialtraining/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/adversarialtraining/train_utils.py b/scenic/projects/adversarialtraining/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/README.md b/scenic/projects/av_mae/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/base_model.py b/scenic/projects/av_mae/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/evaluation_lib.py b/scenic/projects/av_mae/evaluation_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/main.py b/scenic/projects/av_mae/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/mbt.py b/scenic/projects/av_mae/mbt.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/model_utils.py b/scenic/projects/av_mae/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/optimizer_utils.py b/scenic/projects/av_mae/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/registry.py b/scenic/projects/av_mae/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/train_utils.py b/scenic/projects/av_mae/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/trainer.py b/scenic/projects/av_mae/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/trainer_multimodal.py b/scenic/projects/av_mae/trainer_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/transfer_trainer.py b/scenic/projects/av_mae/transfer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/transfer_trainer_multimodal.py b/scenic/projects/av_mae/transfer_trainer_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/vit.py b/scenic/projects/av_mae/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/vivit.py b/scenic/projects/av_mae/vivit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/av_mae/vivit_multimodal.py b/scenic/projects/av_mae/vivit_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/README.md b/scenic/projects/avatar/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/architecture_avatar.png b/scenic/projects/avatar/architecture_avatar.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/decode.py b/scenic/projects/avatar/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/generation_trainer.py b/scenic/projects/avatar/generation_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/main.py b/scenic/projects/avatar/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/metrics_utils.py b/scenic/projects/avatar/metrics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/model_utils.py b/scenic/projects/avatar/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/avatar/models.py b/scenic/projects/avatar/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/README.md b/scenic/projects/baselines/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/__init__.py b/scenic/projects/baselines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/axial_resnet.py b/scenic/projects/baselines/axial_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/bit_resnet.py b/scenic/projects/baselines/bit_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/fully_connected.py b/scenic/projects/baselines/fully_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/hybrid_vit.py b/scenic/projects/baselines/hybrid_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/mixer.py b/scenic/projects/baselines/mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/resnet.py b/scenic/projects/baselines/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/simple_cnn.py b/scenic/projects/baselines/simple_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/unet.py b/scenic/projects/baselines/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/baselines/vit.py b/scenic/projects/baselines/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/README.md b/scenic/projects/boundary_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/__init__.py b/scenic/projects/boundary_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/eval_main.py b/scenic/projects/boundary_attention/eval_main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/eval_manager.py b/scenic/projects/boundary_attention/eval_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/main.py b/scenic/projects/boundary_attention/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/noisy_flower.png b/scenic/projects/boundary_attention/noisy_flower.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/requirements.txt b/scenic/projects/boundary_attention/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/rm.png b/scenic/projects/boundary_attention/rm.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/train_utils.py b/scenic/projects/boundary_attention/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/trainer.py b/scenic/projects/boundary_attention/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/boundary_attention/types.py b/scenic/projects/boundary_attention/types.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/README.md b/scenic/projects/densevoc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/__init__.py b/scenic/projects/densevoc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/chota.py b/scenic/projects/densevoc/chota.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/densevoc_evaluator.py b/scenic/projects/densevoc/densevoc_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/densevoc_framework.png b/scenic/projects/densevoc/densevoc_framework.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/densevoc_teaser.png b/scenic/projects/densevoc/densevoc_teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/evaluate.py b/scenic/projects/densevoc/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/evaluation_utils.py b/scenic/projects/densevoc/evaluation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/input_pipeline.py b/scenic/projects/densevoc/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/input_utils.py b/scenic/projects/densevoc/input_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/main.py b/scenic/projects/densevoc/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/requirements.txt b/scenic/projects/densevoc/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/trainer.py b/scenic/projects/densevoc/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/transforms.py b/scenic/projects/densevoc/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/densevoc/vidstg_evaluator.py b/scenic/projects/densevoc/vidstg_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/fast_vit/README.md b/scenic/projects/fast_vit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/fast_vit/__init__.py b/scenic/projects/fast_vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/fast_vit/main.py b/scenic/projects/fast_vit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/fast_vit/model_utils.py b/scenic/projects/fast_vit/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/fast_vit/xvit.py b/scenic/projects/fast_vit/xvit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/README.md b/scenic/projects/gerald/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/__init__.py b/scenic/projects/gerald/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/ger_eval.py b/scenic/projects/gerald/ger_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/ger_trainer.py b/scenic/projects/gerald/ger_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/gerald_method.png b/scenic/projects/gerald/gerald_method.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/input_pipeline.py b/scenic/projects/gerald/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/main.py b/scenic/projects/gerald/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/prepare_ald_codes.py b/scenic/projects/gerald/prepare_ald_codes.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/gerald/utils.py b/scenic/projects/gerald/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/README.md b/scenic/projects/knowledge_visual_language/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/__init__.py b/scenic/projects/knowledge_visual_language/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/main.py b/scenic/projects/knowledge_visual_language/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/trainer.py b/scenic/projects/knowledge_visual_language/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/trainer_memory.py b/scenic/projects/knowledge_visual_language/trainer_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/knowledge_visual_language/trainer_utils.py b/scenic/projects/knowledge_visual_language/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/lang4video/__init__.py b/scenic/projects/lang4video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/README.md b/scenic/projects/layout_denoise/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/__init__.py b/scenic/projects/layout_denoise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/base_model.py b/scenic/projects/layout_denoise/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/main.py b/scenic/projects/layout_denoise/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/model.py b/scenic/projects/layout_denoise/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/train_utils.py b/scenic/projects/layout_denoise/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/layout_denoise/trainer.py b/scenic/projects/layout_denoise/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/README.md b/scenic/projects/loca/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/__init__.py b/scenic/projects/loca/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/loca.png b/scenic/projects/loca/loca.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/loca_dataset.py b/scenic/projects/loca/loca_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/main.py b/scenic/projects/loca/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/ops.py b/scenic/projects/loca/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/trainer.py b/scenic/projects/loca/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/utils.py b/scenic/projects/loca/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/loca/vit.py b/scenic/projects/loca/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/README.md b/scenic/projects/matvit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/__init__.py b/scenic/projects/matvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/classification_eval_main.py b/scenic/projects/matvit/classification_eval_main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/layers.py b/scenic/projects/matvit/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/main.py b/scenic/projects/matvit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/matvit.py b/scenic/projects/matvit/matvit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/matvit/trainer.py b/scenic/projects/matvit/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/README.md b/scenic/projects/mbt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/__init__.py b/scenic/projects/mbt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/bottlenecks.png b/scenic/projects/mbt/bottlenecks.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/main.py b/scenic/projects/mbt/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/model.py b/scenic/projects/mbt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/model_utils.py b/scenic/projects/mbt/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/requirements.txt b/scenic/projects/mbt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/train_utils.py b/scenic/projects/mbt/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mbt/trainer.py b/scenic/projects/mbt/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/README.md b/scenic/projects/mtv/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/__init__.py b/scenic/projects/mtv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/config_utils.py b/scenic/projects/mtv/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/config_utils_test.py b/scenic/projects/mtv/config_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/main.py b/scenic/projects/mtv/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/model.py b/scenic/projects/mtv/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/model_test.py b/scenic/projects/mtv/model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/model_utils.py b/scenic/projects/mtv/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/model_utils_test.py b/scenic/projects/mtv/model_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/mtv.png b/scenic/projects/mtv/mtv.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/requirements.txt b/scenic/projects/mtv/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/train_utils.py b/scenic/projects/mtv/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/mtv/trainer.py b/scenic/projects/mtv/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/README.md b/scenic/projects/ncr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/__init__.py b/scenic/projects/ncr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/base_model.py b/scenic/projects/ncr/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/classification_trainer.py b/scenic/projects/ncr/classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/loss.py b/scenic/projects/ncr/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/main.py b/scenic/projects/ncr/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/requirements.txt b/scenic/projects/ncr/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/resnet.py b/scenic/projects/ncr/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/ncr/utils.py b/scenic/projects/ncr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/DATA.md b/scenic/projects/objectvivit/DATA.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/README.md b/scenic/projects/objectvivit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/__init__.py b/scenic/projects/objectvivit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/dataset_utils.py b/scenic/projects/objectvivit/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/datasets.py b/scenic/projects/objectvivit/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/main.py b/scenic/projects/objectvivit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/model.py b/scenic/projects/objectvivit/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/model_utils.py b/scenic/projects/objectvivit/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/object_attention.py b/scenic/projects/objectvivit/object_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/optimizer_utils.py b/scenic/projects/objectvivit/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/requirements.txt b/scenic/projects/objectvivit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/train_utils.py b/scenic/projects/objectvivit/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/objectvivit/trainer.py b/scenic/projects/objectvivit/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/omninet/__init__.py b/scenic/projects/omninet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/omninet/main.py b/scenic/projects/omninet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/omninet/model.py b/scenic/projects/omninet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/omninet/model_utils.py b/scenic/projects/omninet/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/README.md b/scenic/projects/owl_vit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/__init__.py b/scenic/projects/owl_vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/evaluator.py b/scenic/projects/owl_vit/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/layers.py b/scenic/projects/owl_vit/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/losses.py b/scenic/projects/owl_vit/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/main.py b/scenic/projects/owl_vit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/matching_base_models.py b/scenic/projects/owl_vit/matching_base_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/models.py b/scenic/projects/owl_vit/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/requirements.txt b/scenic/projects/owl_vit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/trainer.py b/scenic/projects/owl_vit/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/owl_vit/utils.py b/scenic/projects/owl_vit/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/performer/performer.py b/scenic/projects/performer/performer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/performer/subquadratic_attention.py b/scenic/projects/performer/subquadratic_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/performer/subquadratic_attention_test.py b/scenic/projects/performer/subquadratic_attention_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/performer/utils.py b/scenic/projects/performer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/README.md b/scenic/projects/pixel_llm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/__init__.py b/scenic/projects/pixel_llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/auto_regressive_decode.py b/scenic/projects/pixel_llm/auto_regressive_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/densecap_evaluator.py b/scenic/projects/pixel_llm/densecap_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/evaluate.py b/scenic/projects/pixel_llm/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/evaluators.py b/scenic/projects/pixel_llm/evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/main.py b/scenic/projects/pixel_llm/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/partition_utils.py b/scenic/projects/pixel_llm/partition_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/requirements.txt b/scenic/projects/pixel_llm/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/tokenizers.py b/scenic/projects/pixel_llm/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/train_utils.py b/scenic/projects/pixel_llm/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pixel_llm/trainer.py b/scenic/projects/pixel_llm/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/README.md b/scenic/projects/pointcloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/__init__.py b/scenic/projects/pointcloud/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/main.py b/scenic/projects/pointcloud/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/main_s3dis.py b/scenic/projects/pointcloud/main_s3dis.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/main_seg.py b/scenic/projects/pointcloud/main_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/models.py b/scenic/projects/pointcloud/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/models_test.py b/scenic/projects/pointcloud/models_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/pointcloud_dataset.py b/scenic/projects/pointcloud/pointcloud_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/s3dis_dataset.py b/scenic/projects/pointcloud/s3dis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/segmentation_model.py b/scenic/projects/pointcloud/segmentation_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/segmentation_trainer.py b/scenic/projects/pointcloud/segmentation_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/pointcloud/shapenet_dataset.py b/scenic/projects/pointcloud/shapenet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/README.md b/scenic/projects/polyvit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/__init__.py b/scenic/projects/polyvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/layers.py b/scenic/projects/polyvit/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/main.py b/scenic/projects/polyvit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/model.py b/scenic/projects/polyvit/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/model_utils.py b/scenic/projects/polyvit/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/polyvit_base_model.py b/scenic/projects/polyvit/polyvit_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/requirements.txt b/scenic/projects/polyvit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/train_utils.py b/scenic/projects/polyvit/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/polyvit/trainer.py b/scenic/projects/polyvit/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/robust_segvit/README.md b/scenic/projects/robust_segvit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/robust_segvit/__init__.py b/scenic/projects/robust_segvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/README.md b/scenic/projects/streaming_dvc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/__init__.py b/scenic/projects/streaming_dvc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/caption_evaluator.py b/scenic/projects/streaming_dvc/caption_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/cococap_eval.py b/scenic/projects/streaming_dvc/cococap_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/densecap_evaluator.py b/scenic/projects/streaming_dvc/densecap_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/evaluate.py b/scenic/projects/streaming_dvc/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/main.py b/scenic/projects/streaming_dvc/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/optimizer_utils.py b/scenic/projects/streaming_dvc/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/partition_utils.py b/scenic/projects/streaming_dvc/partition_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/post_processing_utils.py b/scenic/projects/streaming_dvc/post_processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/requirements.txt b/scenic/projects/streaming_dvc/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/streaming_dvc_teaser.png b/scenic/projects/streaming_dvc/streaming_dvc_teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/train_utils.py b/scenic/projects/streaming_dvc/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/streaming_dvc/trainer.py b/scenic/projects/streaming_dvc/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/README.md b/scenic/projects/svvit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/__init__.py b/scenic/projects/svvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/classification_trainer.py b/scenic/projects/svvit/classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/inference.py b/scenic/projects/svvit/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/main.py b/scenic/projects/svvit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/metrics.py b/scenic/projects/svvit/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/transfer_trainer.py b/scenic/projects/svvit/transfer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/vit.py b/scenic/projects/svvit/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/svvit/xvit.py b/scenic/projects/svvit/xvit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/README.md b/scenic/projects/t5/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/__init__.py b/scenic/projects/t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/inspect_model.py b/scenic/projects/t5/inspect_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/layers.py b/scenic/projects/t5/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/model.py b/scenic/projects/t5/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/t5/tokenizer.py b/scenic/projects/t5/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/README.md b/scenic/projects/tasseo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/classification_trainer.py b/scenic/projects/tasseo/classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/dataset_utils.py b/scenic/projects/tasseo/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/duplex_vit.py b/scenic/projects/tasseo/duplex_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/duplex_vit_classification_trainer.py b/scenic/projects/tasseo/duplex_vit_classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/inference.py b/scenic/projects/tasseo/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/main.py b/scenic/projects/tasseo/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/train_utils.py b/scenic/projects/tasseo/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/transfer_trainer.py b/scenic/projects/tasseo/transfer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/vit.py b/scenic/projects/tasseo/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/tasseo/xvit.py b/scenic/projects/tasseo/xvit.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_learner/README.md b/scenic/projects/token_learner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_learner/__init__.py b/scenic/projects/token_learner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_learner/main.py b/scenic/projects/token_learner/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_learner/model.py b/scenic/projects/token_learner/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_turing/README.md b/scenic/projects/token_turing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/token_turing/model.py b/scenic/projects/token_turing/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/univrd/README.md b/scenic/projects/univrd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/README.md b/scenic/projects/unloc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/action_segmentation_base_model.py b/scenic/projects/unloc/action_segmentation_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/action_segmentation_base_model_test.py b/scenic/projects/unloc/action_segmentation_base_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/activity_net_eval.py b/scenic/projects/unloc/activity_net_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/config_utils.py b/scenic/projects/unloc/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/encoders.py b/scenic/projects/unloc/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/encoders_test.py b/scenic/projects/unloc/encoders_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/eval_utils.py b/scenic/projects/unloc/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/eval_utils_test.py b/scenic/projects/unloc/eval_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/evaluator.py b/scenic/projects/unloc/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/heads.py b/scenic/projects/unloc/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/heads_test.py b/scenic/projects/unloc/heads_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/main.py b/scenic/projects/unloc/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/metrics.py b/scenic/projects/unloc/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/metrics_test.py b/scenic/projects/unloc/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/model.py b/scenic/projects/unloc/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/model_test.py b/scenic/projects/unloc/model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/model_utils.py b/scenic/projects/unloc/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/model_utils_test.py b/scenic/projects/unloc/model_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/moment_retrieval_base_model.py b/scenic/projects/unloc/moment_retrieval_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/moment_retrieval_base_model_test.py b/scenic/projects/unloc/moment_retrieval_base_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/optimizer_utils.py b/scenic/projects/unloc/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/optimizer_utils_test.py b/scenic/projects/unloc/optimizer_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/postprocessing_utils.py b/scenic/projects/unloc/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/postprocessing_utils_test.py b/scenic/projects/unloc/postprocessing_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/single_task_trainer.py b/scenic/projects/unloc/single_task_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/temporal_localization_base_model.py b/scenic/projects/unloc/temporal_localization_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/temporal_localization_base_model_test.py b/scenic/projects/unloc/temporal_localization_base_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/train_utils.py b/scenic/projects/unloc/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/unloc.png b/scenic/projects/unloc/unloc.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/video_text_fusion.py b/scenic/projects/unloc/video_text_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/unloc/video_text_fusion_test.py b/scenic/projects/unloc/video_text_fusion_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/README.md b/scenic/projects/verbs_in_action/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/__init__.py b/scenic/projects/verbs_in_action/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/clip4clip_model.py b/scenic/projects/verbs_in_action/clip4clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/kinetics-400-verb-classes.txt b/scenic/projects/verbs_in_action/kinetics-400-verb-classes.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/losses.py b/scenic/projects/verbs_in_action/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/main.py b/scenic/projects/verbs_in_action/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/tfrecord_dataset.py b/scenic/projects/verbs_in_action/tfrecord_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/trainer.py b/scenic/projects/verbs_in_action/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/utils.py b/scenic/projects/verbs_in_action/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/verbs_in_action/vfc.png b/scenic/projects/verbs_in_action/vfc.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/README.md b/scenic/projects/vid2seq/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/__init__.py b/scenic/projects/vid2seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/data_utils.py b/scenic/projects/vid2seq/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/dvc_eval.py b/scenic/projects/vid2seq/dvc_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/generate_from_file.py b/scenic/projects/vid2seq/generate_from_file.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/load_utils.py b/scenic/projects/vid2seq/load_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/main.py b/scenic/projects/vid2seq/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/models.py b/scenic/projects/vid2seq/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/requirements.txt b/scenic/projects/vid2seq/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/train_utils.py b/scenic/projects/vid2seq/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/trainer.py b/scenic/projects/vid2seq/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vid2seq/vid2seq.png b/scenic/projects/vid2seq/vid2seq.png new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/README.md b/scenic/projects/vivit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/__init__.py b/scenic/projects/vivit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/evaluation_lib.py b/scenic/projects/vivit/evaluation_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/main.py b/scenic/projects/vivit/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/model.py b/scenic/projects/vivit/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/model_utils.py b/scenic/projects/vivit/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/requirements.txt b/scenic/projects/vivit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/train_utils.py b/scenic/projects/vivit/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/projects/vivit/trainer.py b/scenic/projects/vivit/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/__init__.py b/scenic/train_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/__pycache__/__init__.cpython-310.pyc b/scenic/train_lib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/__pycache__/optimizers.cpython-310.pyc b/scenic/train_lib/__pycache__/optimizers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/__pycache__/train_utils.cpython-310.pyc b/scenic/train_lib/__pycache__/train_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/classification_trainer.py b/scenic/train_lib/classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..adcd70946c16cbc9be1f7ae227422b1176f76fa9 --- /dev/null +++ b/scenic/train_lib/classification_trainer.py @@ -0,0 +1,423 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training Script.""" + +import functools +from typing import Any, Callable, Dict, Tuple, Optional, Type + +from absl import logging +from clu import metric_writers +from clu import periodic_actions +from clu import platform +import flax +from flax import jax_utils +import flax.linen as nn +import jax +from jax.example_libraries.optimizers import clip_grads +import jax.numpy as jnp +import jax.profiler +import ml_collections +import numpy as np +import optax +from scenic.dataset_lib import dataset_utils +from scenic.model_lib.base_models import base_model +from scenic.train_lib import lr_schedules +from scenic.train_lib import optimizers +from scenic.train_lib import train_utils + +# Aliases for custom types: +Batch = Dict[str, jnp.ndarray] +MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], + Dict[str, Tuple[float, int]]] +LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray]], float] +LrFn = Callable[[jnp.ndarray], jnp.ndarray] + +flax.config.update('flax_use_orbax_checkpointing', False) + + +def train_step( + train_state: train_utils.TrainState, + batch: Batch, + *, + flax_model: nn.Module, + loss_fn: LossFn, + lr_fn: LrFn, + metrics_fn: MetricFn, + config: ml_collections.ConfigDict, + debug: Optional[bool] = False +) -> Tuple[train_utils.TrainState, Dict[str, Tuple[float, int]], Dict[str, + Any]]: + """Runs a single step of training. + + Given the state of the training and a batch of data, computes + the loss and updates the parameters of the model. + + Note that in this code, the buffers of the first (train_state) and second + (batch) arguments are donated to the computation. + + Args: + train_state: The state of training including the current global_step, + model_state, rng, params, and optimizer. The buffer of this argument can + be donated to the computation. + batch: A single batch of data. The buffer of this argument can be donated to + the computation. + flax_model: A Flax model. + loss_fn: A loss function that given logits, a batch, and parameters of the + model calculates the loss. + lr_fn: The learning rate fn used for the logging the learning rate. + metrics_fn: A metrics function that given logits and batch of data, + calculates the metrics as well as the loss. + config: Configurations of the experiment. + debug: Whether the debug mode is enabled during training. `debug=True` + enables model specific logging/storing some values using + jax.host_callback. + + Returns: + Updated state of training and computed metrics and some training logs. + """ + training_logs = {} + new_rng, rng = jax.random.split(train_state.rng) + + if config.get('mixup') and config.mixup.alpha: + mixup_rng, rng = jax.random.split(rng, 2) + mixup_rng = train_utils.bind_rng_to_host_device( + mixup_rng, + axis_name='batch', + bind_to=config.mixup.get('bind_to', 'device')) + batch = dataset_utils.mixup( + batch, + config.mixup.alpha, + config.mixup.get('image_format', 'NHWC'), + rng=mixup_rng) + + # Bind the rng to the host/device we are on. + dropout_rng = train_utils.bind_rng_to_host_device( + rng, axis_name='batch', bind_to='device') + + def training_loss_fn(params): + variables = {'params': params, **train_state.model_state} + logits, new_model_state = flax_model.apply( + variables, + batch['inputs'], + mutable=['batch_stats'], + train=True, + rngs={'dropout': dropout_rng}, + debug=debug) + loss = loss_fn(logits, batch, variables['params']) + return loss, (new_model_state, logits) + + compute_gradient_fn = jax.value_and_grad(training_loss_fn, has_aux=True) + (train_cost, (new_model_state, + logits)), grad = compute_gradient_fn(train_state.params) + + del train_cost + # Re-use same axis_name as in the call to `pmap(...train_step...)` below. + grad = jax.lax.pmean(grad, axis_name='batch') + + if config.get('max_grad_norm') is not None: + grad = clip_grads(grad, config.max_grad_norm) + + tx = train_state.tx + if tx is None: + raise ValueError('train_state.tx, the Gradient Transformation, is None') + updates, new_opt_state = tx.update( + grad, train_state.opt_state, train_state.params + ) + new_params = optax.apply_updates(train_state.params, updates) + + training_logs['l2_grads'] = jnp.sqrt( + sum([jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(grad)]) + ) + ps = jax.tree_util.tree_leaves(new_params) + training_logs['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) + us = jax.tree_util.tree_leaves(updates) + training_logs['l2_updates'] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) + # TODO(dehghani): Can we get this from the optimizer instead? + training_logs['learning_rate'] = lr_fn(jnp.asarray([train_state.global_step])) + + metrics = metrics_fn(logits, batch) + + new_train_state = train_state.replace( # pytype: disable=attribute-error + global_step=train_state.global_step + 1, + opt_state=new_opt_state, + params=new_params, + model_state=new_model_state, + rng=new_rng) + + return new_train_state, metrics, training_logs + + +def eval_step( + train_state: train_utils.TrainState, + batch: Batch, + *, + flax_model: nn.Module, + metrics_fn: MetricFn, + debug: Optional[bool] = False +) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray]: + """Runs a single step of training. + + Note that in this code, the buffer of the second argument (batch) is donated + to the computation. + + Assumed API of metrics_fn is: + ```metrics = metrics_fn(logits, batch) + where batch is yielded by the batch iterator, and metrics is a dictionary + mapping metric name to a vector of per example measurements. eval_step will + aggregate (by summing) all per example measurements and divide by the + aggregated normalizers. For each given metric we compute: + 1/N sum_{b in batch_iter} metric(b), where N is the sum of normalizer + over all batches. + + Args: + train_state: TrainState, the state of training including the current + global_step, model_state, rng, params and optimizer state. The buffer of + this argument can be donated to the computation. + batch: A single batch of data. a metrics function, that given logits and + batch of data, calculates the metrics as well as the loss. + flax_model: A Flax model. + metrics_fn: A metrics function, that given logits and batch of data, + calculates the metrics as well as the loss. + debug: Whether the debug mode is enabled during evaluation. `debug=True` + enables model specific logging/storing some values using + jax.host_callback. + + Returns: + Calculated metrics and logits. + """ + variables = {'params': train_state.params, **train_state.model_state} + logits = flax_model.apply( + variables, batch['inputs'], train=False, mutable=False, debug=debug) + metrics = metrics_fn(logits, batch) + return metrics, logits + + +def train( + *, + rng: jnp.ndarray, + config: ml_collections.ConfigDict, + model_cls: Type[base_model.BaseModel], + dataset: dataset_utils.Dataset, + workdir: str, + writer: metric_writers.MetricWriter, +) -> Tuple[train_utils.TrainState, Dict[str, Any], Dict[str, Any]]: + """Main training loop lives in this function. + + Given the model class and dataset, it prepares the items needed to run the + training, including the TrainState. + + Args: + rng: Jax rng key. + config: Configurations of the experiment. + model_cls: Model class; A model has a flax_module, a loss_fn, and a + metrics_fn associated with it. + dataset: The dataset that has train_iter, eval_iter, meta_data, and + optionally, test_iter. + workdir: Directory for checkpointing. + writer: CLU metrics writer instance. + + Returns: + train_state that has the state of training (including current + global_step, model_state, rng, and the optimizer), train_summary + and eval_summary which are dict of metrics. These outputs are used for + regression testing. + """ + lead_host = jax.process_index() == 0 + # Build the loss_fn, metrics, and flax_model. + model = model_cls(config, dataset.meta_data) + + # Initialize model. + rng, init_rng = jax.random.split(rng) + (params, model_state, num_trainable_params, + gflops) = train_utils.initialize_model( + model_def=model.flax_model, + input_spec=[(dataset.meta_data['input_shape'], + dataset.meta_data.get('input_dtype', jnp.float32))], + config=config, + rngs=init_rng) + + # Create optimizer. + lr_fn = lr_schedules.get_learning_rate_fn(config) + optimizer_config = optimizers.get_optax_optimizer_config(config) + # If the config is already an optax-compatible config, better call directly: + # optimizers.get_optimizer(config.optimizer_configs, lr_fn) + tx = optimizers.get_optimizer(optimizer_config, lr_fn, params=params) + # We jit this, such that the arrays that are created on the same device as the + # input is, in this case the CPU. Else they'd be on device[0]. + opt_state = jax.jit(tx.init, backend='cpu')(params) + + rng, train_rng = jax.random.split(rng) + + # Create chrono class to track and store training statistics and metadata: + chrono = train_utils.Chrono() + + train_state = train_utils.TrainState( + global_step=0, + opt_state=opt_state, + tx=tx, + params=params, + model_state=model_state, + rng=train_rng, + metadata={'chrono': chrono.save()}) + start_step = train_state.global_step + if config.checkpoint: + train_state, start_step = train_utils.restore_checkpoint( + workdir, train_state) + chrono.load(train_state.metadata['chrono']) + train_state = train_state.replace(metadata={}) + # Replicate the optimizer, state, and rng. + train_state = jax_utils.replicate(train_state) + del params # Do not keep a copy of the initial params. + + # Calculate the total number of training steps. + total_steps, steps_per_epoch = train_utils.get_num_training_steps( + config, dataset.meta_data) + + train_step_pmapped = jax.pmap( + functools.partial( + train_step, + flax_model=model.flax_model, + loss_fn=model.loss_function, + lr_fn=lr_fn, + metrics_fn=model.get_metrics_fn('train'), + config=config, + debug=config.debug_train), + axis_name='batch', + # We can donate both buffers of train_state and train_batch. + donate_argnums=(0, 1), + ) + eval_step_pmapped = jax.pmap( + functools.partial( + eval_step, + flax_model=model.flax_model, + metrics_fn=model.get_metrics_fn('validation'), + debug=config.debug_eval), + axis_name='batch', + # We can donate the eval_batch's buffer. + donate_argnums=(1,), + ) + log_eval_steps = config.get('log_eval_steps') or steps_per_epoch + if not log_eval_steps: + raise ValueError("'log_eval_steps' should be specified in the config.") + checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps + max_checkpoint_keep = config.get('max_checkpoint_keep', 3) + log_summary_steps = config.get('log_summary_steps') or log_eval_steps + + # Ceil rounding such that we include the last incomplete batch. + eval_batch_size = config.get('eval_batch_size', config.batch_size) + total_eval_steps = int( + np.ceil(dataset.meta_data['num_eval_examples'] / eval_batch_size)) + steps_per_eval = config.get('steps_per_eval') or total_eval_steps + + train_metrics, extra_training_logs = [], [] + train_summary, eval_summary = None, None + + chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch) + logging.info('Starting training loop at step %d.', start_step + 1) + report_progress = periodic_actions.ReportProgress( + num_train_steps=total_steps, + writer=writer, + every_secs=None, + every_steps=config.get('report_progress_step', log_summary_steps), + ) + + def write_note(note): + if lead_host: + platform.work_unit().set_notes(note) + + hooks = [] + if lead_host: + hooks.append(report_progress) + if config.get('xprof', True) and lead_host: + hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir)) + + if start_step == 0: + step0_log = {'num_trainable_params': num_trainable_params} + if gflops: + step0_log['gflops'] = gflops + writer.write_scalars(1, step0_log) + + write_note(f'First step compilations...\n{chrono.note}') + for step in range(start_step + 1, total_steps + 1): + with jax.profiler.StepTraceAnnotation('train', step_num=step): + train_batch = next(dataset.train_iter) + train_state, t_metrics, t_logs = train_step_pmapped( + train_state, train_batch) + # This will accumulate metrics in TPU memory up to the point that we log + # them. This is no problem for small metrics but may be a problem for + # large (e.g. segmentation) metrics. An alternative is to set + # `log_summary_steps` to a small number, or to use + # `train_utils.unreplicate_and_get` here instead of right before writing + # summaries, but that means in each step, we have data transfer between + # tpu and host, which might slow down the training. + train_metrics.append(t_metrics) + # Additional training logs: learning rate: + t_logs = jax.tree_util.tree_map(jax_utils.unreplicate, t_logs) + extra_training_logs.append(t_logs) + for h in hooks: + h(step) + # Below are once-in-a-while ops -> pause. + ###################### LOG TRAIN SUMMARY ######################## + if ((step % log_summary_steps == 1) or (step == total_steps) or + (lead_host and chrono.warmup)): + chrono.pause(wait_for=(train_metrics)) + if lead_host: + chrono.tick(step, writer, write_note) + # train_metrics is list of a dictionaries of metrics, where the shape of + # the metrics[key] is [n_local_devices]. However, because metric functions + # have a psum, we have already summed across the whole sharded batch, and + # what's returned is n_local_devices copies of the same summed metric. + # So we do unreplicate and fetch them to host using `unreplicate_and_get`. + train_summary = train_utils.log_train_summary( + step=step, + train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get, + train_metrics), + extra_training_logs=jax.tree_util.tree_map(jax.device_get, + extra_training_logs), + writer=writer) + # Reset metric accumulation for next evaluation cycle. + train_metrics, extra_training_logs = [], [] + chrono.resume() + ################### EVALUATION ####################### + if (step % log_eval_steps == 1) or (step == total_steps): + chrono.pause(wait_for=(train_state.params)) + with report_progress.timed('eval'): + eval_metrics = [] + # Sync model state across replicas. + train_state = train_utils.sync_model_state_across_replicas(train_state) + for _ in range(steps_per_eval): + eval_batch = next(dataset.valid_iter) + e_metrics, _ = eval_step_pmapped(train_state, eval_batch) + eval_metrics.append(train_utils.unreplicate_and_get(e_metrics)) + eval_summary = train_utils.log_eval_summary( + step=step, eval_metrics=eval_metrics, writer=writer) + writer.flush() + del eval_metrics + chrono.resume() + ##################### CHECKPOINTING ################### + if ((step % checkpoint_steps == 1 and step > 1) or + (step == total_steps)) and config.checkpoint: + chrono.pause(wait_for=(train_state.params, train_state.opt_state)) + with report_progress.timed('checkpoint'): + train_utils.handle_checkpointing( + train_state, chrono, workdir, max_checkpoint_keep) + chrono.resume() + + # Wait until computations are done before exiting. + train_utils.barrier_across_hosts() + # Return the train and eval summary after last step for regression testing. + assert train_summary is not None + assert eval_summary is not None + return train_state, train_summary, eval_summary diff --git a/scenic/train_lib/lr_schedules.py b/scenic/train_lib/lr_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0f94404a0ac62c9172580b722ede6774c5d715 --- /dev/null +++ b/scenic/train_lib/lr_schedules.py @@ -0,0 +1,329 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines different learning_rate schedules.""" + +import jax.numpy as jnp +import ml_collections + + +def polynomial_lr_scheduler(step, decay_steps, end_factor, power): + """Same behavior as tf.train.polynomial_decay. + + This is the original formula for this learning rate scheduler: + ``` + end_learning_rate = config['base_learning_rate'] * config['end_factor'] + step = min(config['decay_steps'], step) + decayed_learning_rate = (config['base_learning_rate'] - + end_learning_rate) * ( + 1 - step / config['decay_steps'])**( + config['power']) + end_learning_rate + ``` + We rewrite this as a multiplicative factor for the initial learning rate. + Args: + step: int; Current step. + decay_steps: int; Parameter of the decay function. + end_factor: float; Final lr is: initial lr x end_factor. + power: int; Parameter of the decay function. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + + decay = step <= decay_steps + decayed_learning_rate = (1 - end_factor) * ( + decay * (1 - step / decay_steps))**(power) + end_factor + return decayed_learning_rate + + +def piecewise_constant_scheduler(step, decay_events, decay_factors): + """Gives a scaling factor based on Piecewise Constant scheduling. + + Args: + step: int; Current step. + decay_events: list(int); List of steps in which a decay is applied. + decay_factors: list(int); List containing the absolute ratio of the decay + applied on the decay events. Note that each element of decay_factors is + absolute (not relative). For example, to decay the learning rate to 0.5 of + its initial value after 100 steps, followed by 0.1 of its *initial value* + after 200 steps, with a plateau of 0.1 of its initial value thereafter, + use decay_events = [100, 200] and decay_factors = [0.5, 0.1]. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + boundaries = jnp.array(decay_events) + factors = jnp.array([1.0] + decay_factors) + index = jnp.sum(boundaries < step) + ratio = jnp.take(factors, index) + return ratio + + +def piecewise_linear_scheduler(step, decay_events, decay_factors): + """Gives a scaling factor based on Piecewise Linear scheduling. + + Args: + step: int; Current step. + decay_events: list(int); List of steps in which a decay is applied. + decay_factors: list(int); List containing the absolute ratio of the decay + applied on the decay events. Note that each element of decay_factors is + absolute (not relative). For example, to decay the learning rate to 0.5 of + its initial value after 100 steps, followed by 0.1 of its *initial value* + after 200 steps, with a plateau of 0.1 of its initial value thereafter, + use decay_events = [100, 200] and decay_factors = [0.5, 0.1]. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + boundaries = jnp.array([0] + decay_events + [step]) + factors = jnp.array([1.0] + decay_factors + [decay_factors[-1]]) + index = jnp.sum(boundaries[1:] < step) + m = jnp.take(factors, index + 1) - jnp.take(factors, index) + n = jnp.take(boundaries, index + 1) - jnp.take(boundaries, index) + a = m / jnp.clip(n, 1) + interpolated_factor = ( + a * (step - jnp.take(boundaries, index)) + jnp.take(factors, index)) + return interpolated_factor + + +def linear_warmup_scheduler(step, warmup_steps, alpha=0.): + """Gives a scaling factor based on scheduling with a Linear Warmup. + + Args: + step: int; Current step. + warmup_steps: int; How many steps to warm up for in the warmup schedule. + alpha: float: The minimum value as a fraction of the initial value. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + if warmup_steps > 0: + return jnp.minimum(1.0, alpha + step * (1.0 - alpha) / warmup_steps) + else: + return 1.0 + + +def decay_every_scheduler(step, steps_per_decay, decay_factor): + """Gives a scaling factor based on scheduling with a decay every n-steps. + + Args: + step: int; Current step. + steps_per_decay: int; How often to decay. + decay_factor: float; The amount to decay. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + return decay_factor**(step // steps_per_decay) + + +def exponential_decay_scheduler(step, decay_steps, decay_rate, staircase=False): + """Gives a scaling factor based on scheduling with an exponential decay. + + Args: + step: int; Current step. + decay_steps: int; Number of steps to decay over. + decay_rate: float; Rate of exponential decay. + staircase: bool; If True, use integer division in scale-computation. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + progress = step / float(decay_steps) + if staircase: + progress = jnp.floor(progress) + return jnp.power(decay_rate, progress) + + +def cosine_decay_scheduler(step, steps_per_cycle, t_mul=1, m_mul=1., alpha=0.): + """Gives a scaling factor based on scheduling with a cosine decay. + + Args: + step: int; Current step. + steps_per_cycle: int; Number of steps to reset the decay cycle. + t_mul: int; Used to derive the number of iterations in the i-th period. + m_mul: float; Used to derive the initial learning rate of the i-th period. + alpha: float; The minimum value as a fraction of the initial value. + + Returns: + Scaling factor applied to the learning rate on the given step. + """ + if steps_per_cycle <= 0: + raise ValueError(f'steps_per_cycle must be > 0. Got {steps_per_cycle}.') + progress = step / float(steps_per_cycle) + if t_mul == 1.0: + i_restart = jnp.floor(progress) + progress -= i_restart + else: + i_restart = jnp.floor( + jnp.log(1.0 - progress * (1.0 - t_mul)) / jnp.log(t_mul)) + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + progress = (progress - sum_r) / t_mul**i_restart + m_fac = m_mul**i_restart + cosine_decay = jnp.maximum( + 0.0, 0.5 * m_fac * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) + return (1 - alpha) * cosine_decay + alpha + + +def compound_lr_scheduler(config): + """Creates a learning rate scheduler by combining multiple factors. + + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * decay_every: Every k steps decay the learning rate by decay_factor. + * cosine_decay: Cyclic cosine decay. + + For instance, `config['factors'] = 'constant*linear_warmup'` combines the + constant learning rate schedule with a linear warmup. This requires one to + have the following configuration entries: + config['warmup_steps'] and config['base_learning_rate']. + + Args: + config: Relevant config based on the chosen factors. + + Returns: + lr_fn: A function mapping global_step to lr. + """ + + ratio_factors = [n.strip() for n in config['factors'].split('*')] + + def lr_fn(step): + """Step to learning rate function.""" + ratio = 1.0 + for name in ratio_factors: + if name == 'constant': + ratio *= config['base_learning_rate'] + elif name == 'polynomial': + decay_steps = config['decay_steps'] + end_factor = config['end_factor'] + power = config['power'] + ratio *= polynomial_lr_scheduler(step, decay_steps, end_factor, power) + elif name == 'piecewise_constant': + decay_events = config['decay_events'] + decay_factors = config['decay_factors'] + ratio *= piecewise_constant_scheduler(step, decay_events, decay_factors) + + elif name == 'piecewise_linear': + decay_events = config['decay_events'] + decay_factors = config['decay_factors'] + ratio *= piecewise_linear_scheduler(step, decay_events, decay_factors) + + elif name == 'linear_warmup': + warmup_steps = config['warmup_steps'] + warmup_alpha = config.get('warmup_alpha', 0) + ratio *= linear_warmup_scheduler(step, warmup_steps, warmup_alpha) + + elif name == 'rsqrt_decay': + warmup_steps = config.get('warmup_steps', 0.) + timescale = config.get('timescale', 10_000) + shift = timescale - warmup_steps + ratio *= jnp.where(warmup_steps < step, + jnp.sqrt(timescale) / jnp.sqrt(step + shift), 1.) + + elif name == 'decay_every': + steps_per_decay = config['steps_per_decay'] + decay_factor = config['decay_factor'] + ratio *= decay_every_scheduler(step, steps_per_decay, decay_factor) + + elif name == 'exponential_decay': + decay_steps = config['decay_steps'] + decay_rate = config['decay_rate'] + staircase = config.get('staircase', False) + ratio *= exponential_decay_scheduler( + step, decay_steps, decay_rate, staircase=staircase) + + elif name == 'cosine_decay': + steps_per_cycle = config['steps_per_cycle'] + t_mul = config.get('t_mul', 1.) + m_mul = config.get('m_mul', 1.) + alpha = config.get('alpha', 0.0) + warmup_steps = config.get('warmup_steps', 0.) + adjusted_step = jnp.maximum( + 0.0, (step - (warmup_steps + config.get('start_decay_step', 0.)))) + total_steps = config.get('total_steps', steps_per_cycle) + + # We make the cos equal and subtract warmup steps for each cycle. If + # there are fewer steps than warmup steps, cosine can be skipped. + steps_per_cycle = steps_per_cycle - int( + warmup_steps / (total_steps / steps_per_cycle)) + if steps_per_cycle > 0: + ratio *= cosine_decay_scheduler( + adjusted_step, + steps_per_cycle, + t_mul=t_mul, + m_mul=m_mul, + alpha=alpha) + elif name == 'linear_decay': + warmup_steps = config.get('warmup_steps', 0.) + total_steps = config.get('total_steps') + assert total_steps > warmup_steps, ( + 'With linear decay, total_steps should be higher than warmup_steps.' + ) + progress = jnp.maximum(0.0, (step - warmup_steps) / + float(total_steps - warmup_steps)) + ratio -= config.get('end_learning_rate', 0.) + ratio *= jnp.maximum(1.0 - progress, 0.0) + ratio += config.get('end_learning_rate', 0.) + + elif name == 'linear_cooldown': + adjusted_step = jnp.maximum(step, config.get('warmup_steps', 0.)) + ratio *= jnp.minimum(1., (config.total_steps - adjusted_step) / + config.cooldown_steps) + + else: + raise ValueError('Unknown factor %s.' % name) + + return jnp.asarray(ratio, dtype=jnp.float32) + + return lr_fn + + +lr_fn_dict = { + 'compound': compound_lr_scheduler, +} + + +def get_learning_rate_fn(config: ml_collections.ConfigDict): + """Looks up for the learning rate scheduler and return lr_fn. + + Args: + config: ConfigDict that has configuration ofthe learning rate function. + + Returns: + An learning rate or a function learning_rate(step): float -> + {'learning_rate': float}, the step-dependent lr. + + """ + if 'base_learning_rate' not in config.lr_configs: + raise ValueError( + '`base_learning_rate` has to be defined in the lr_config.') + if not config.lr_configs.base_learning_rate: + # raise ValueError( # raised for {0, False, None, [], (), {}} + # f'`base_learning_rate = {config.lr_configs.base_learning_rate}` is not ' + # 'allowed for training parameters. If your intention was to freeze ' + # 'parameters, use Scenic optax and `config.lr_configs = None` instead.') + pass + # Circumvent failing of config.lr_configs.base_learning_rate in {0, False, + # None, [], (), {}} here as a short-term solution. This case is for now + # handled in optax.make to handle edge cases. + if 'learning_rate_schedule' in config.lr_configs: + # A function that given the current step, returns the LR. + return lr_fn_dict[config.lr_configs['learning_rate_schedule']]( + config.lr_configs) + else: + # LR as a scalar value. + lr = jnp.asarray(config.lr_configs.base_learning_rate, dtype=jnp.float32) + return lambda step: lr diff --git a/scenic/train_lib/optax.py b/scenic/train_lib/optax.py new file mode 100644 index 0000000000000000000000000000000000000000..8066c7b94bf308cf7497baaff55d69b8c44ea637 --- /dev/null +++ b/scenic/train_lib/optax.py @@ -0,0 +1,370 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optax utils for Scenic. + +This is a fork of +https://github.com/google-research/big_vision/blob/main/big_vision/optax.py. +""" + +import itertools +import numbers +import operator +import re +from typing import Any, Optional, Sequence, Tuple, Union, Callable, List + +from absl import logging +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax +from scenic.train_lib import lr_schedules +from scenic.train_lib import optimizers + + +def find_states(opt_state, cls): + leaves = jax.tree_util.tree_leaves( + opt_state, is_leaf=lambda node: isinstance(node, cls)) + return [leaf for leaf in leaves if isinstance(leaf, cls)] + + +def get_step(opt_state): + """Returns `ScaleByScheduleState.count` from `opt_state` as an integer.""" + counts = { + int(state.count) + for state in find_states(opt_state, optax.ScaleByScheduleState) + } + assert len(counts) == 1, f'Expected exactly 1 ScaleByScheduleState: {counts}' + return next(iter(counts)) + + +def _make_mask_trees( + params, + patterns_names_values: Union[Sequence[Tuple[str, str, Any]], + Sequence[Tuple[str, Any]]], + *, + allow_unmatched: bool = False, + log: Optional[str] = None): + """Wrapper around `make_mask_trees` that supports different input types.""" + if patterns_names_values: + if len(patterns_names_values[0]) == 3: + patterns, names, values = zip(*patterns_names_values) + else: + patterns, values = zip(*patterns_names_values) + names = [None] * len(values) + else: + patterns, names, values = [], [], [] + + masks = make_mask_trees( + params, + list(zip(patterns, names, values)), + allow_unmatched=allow_unmatched, + log=log, + ) + return masks, list(zip(names, values)) + + +def _split_frozen(masks, scheds): + """Computes `frozen_mask` and updates `masks` and `scheds`.""" + + def _is_none(sched): + """Helper to check if sched itself or (fn, base_lr) of sched are None.""" + if isinstance(sched, (tuple, list)): + _, fn_base_lr = sched # Check only the tuple of fn and base_lr. + return not any(fn_base_lr) # Only false if fn_base_lr = (None, None) + else: + return sched is None + + # Specifying `None` as a scheduler freezes params. + all_false = jax.tree_util.tree_map(lambda *bools: not any(bools), *masks) + frozen_masks = [ + mask for mask, sched in zip(masks, scheds) if _is_none(sched)] + frozen_mask = jax.tree_util.tree_map( + lambda *bools: any(bools), *frozen_masks, + all_false) # `all_false` is required when `frozen_masks==[]`. + masks, scheds = zip(*( + (mask, sched) for mask, sched in zip(masks, + scheds) if not _is_none(sched))) + return frozen_mask, masks, scheds + + +def make_mask_trees( + tree, + patterns_names: Sequence[Tuple[str, Optional[str], float]], + *, + allow_unmatched: bool = False, + log: Optional[str] = None, +): + """Returns a boolean mask tree for every pattern (only first match).""" + + patterns, _, _ = zip(*patterns_names) + compiled_patterns = list(map(re.compile, patterns)) + + def matchfirst(_, name): + matches = [bool(pattern.fullmatch(name)) for pattern in compiled_patterns] + + matched = sum(map(int, matches)) + matched_patterns = [patterns_names[i] for i, m in enumerate(matches) if m] + if matched > 1: + raise ValueError( + f'{name} matched by multiple patterns: {matched_patterns}') + + if matched == 0 and not allow_unmatched: + raise ValueError(f'{name} was *not* matched by a single pattern!') + + if log is not None: + if any(matches): + logging.info('%s: %s - matched by %s', log, name, + patterns_names[matches.index(True)]) + else: + logging.info('%s: %s - not matched by any patterns', log, name) + return np.array(matches) + + multimask = optimizers.tree_map_with_names_values(matchfirst, tree) + return [ + jax.tree_util.tree_map(lambda matches, i=idx: matches[i], multimask) + for idx in range(len(patterns)) + ] + + +def replace_frozen(schedule, pytree, replacement, log: Optional[str] = None): + """Replaces values matching frozen params in `pytree` with `replacement`.""" + if schedule is None: + return pytree + schedule = [(cfg.re, cfg.lr_configs) for name, cfg in schedule.items()] + + masks, scheds = _make_mask_trees(pytree, schedule, log=log) + frozen_mask, _, _ = _split_frozen(masks, [value for _, value in scheds]) + return jax.tree_util.tree_map( + lambda v, f: replacement if f else v, pytree, frozen_mask) + + +def make_schedule( + schedule: Optional[ml_collections.ConfigDict] = None, + get_learning_rate_fn: Callable[ + [ml_collections.ConfigDict], + optax.ScalarOrSchedule] = lr_schedules.get_learning_rate_fn, +) -> List[Tuple[str, str, Tuple[optax.ScalarOrSchedule, float]]]: + """Creates a schedule dictionary compatible with the `make` function.""" + # Global schedule. No schedule means frozen. + if schedule is None: + schedule = ml_collections.ConfigDict( + {'all': ml_collections.ConfigDict({'re': '(.*)', 'lr_configs': None})}) + schedule = [(cfg.re, name, cfg.lr_configs) for name, cfg in schedule.items()] + + # Create actual schedules funtions. + def create_schedule(lr_configs): + if lr_configs is None: + return None, None # Parameters are frozen + fn = get_learning_rate_fn( + ml_collections.ConfigDict({'lr_configs': lr_configs})) + # Base LR is used for decoupling WD from LR schedules. + base_lr = lr_configs.get('base_learning_rate', 1.0) + return fn, base_lr + + schedule = [(re, name, create_schedule(lr_configs)) + for re, name, lr_configs in schedule] + return schedule + + +def make(config: ml_collections.ConfigDict, + schedule: Sequence[ + Tuple[str, str, Tuple[optax.ScalarOrSchedule, float]]], + params): + """Returns gradient transform and learning rate functions. + + Args: + config: Optimizer config. + schedule: Learning rate schedules as tuple of regexp, name, learning rate + schedule function and base learning rate (for WD decoupling). + params: Model parameters. + """ + if not config.get('per_example_clipping'): + # Collect all base_lrs and transform to bool. Each element of schedule fol- + # lows the structure (re, name, (fn, base_lr)) [see above]. + base_lrs = [fn_base_lr[1] for _, _, fn_base_lr in schedule] + if any([base_lr == 0 for base_lr in base_lrs]): + raise ValueError( # raised if base_lr = 0 + f'`base_learning_rate` contains unsupported values {base_lrs}. If ' + 'your intention was to freeze parameters, use Scenic optax and ' + '`config.lr_configs = None` instead.') + masks, scheds = _make_mask_trees(params, schedule, log='schedule') + frozen_mask, masks, scheds = _split_frozen(masks, scheds) + not_frozen_mask = jax.tree_util.tree_map(operator.not_, frozen_mask) + schedule_fns, schedule_base_lr = zip( + *[fn_base for _, fn_base in (scheds or [])]) + schedule_txs = [ + optax.masked(optax.scale_by_schedule(schedule_fn), mask) + for schedule_fn, mask in zip(schedule_fns, masks) + ] + [ + # Removes weight decay updates. Note that weight decay already has an + # independent mask (which cannot be combined easily with a second mask), + # so instead we multiply updates for frozen params with zero. + optax.masked(optax.set_to_zero(), frozen_mask) + ] + + # Gradient clipping. + grad_clip_norm_tx = [] + if config.get('max_grad_norm'): + if not config.get('per_example_clipping'): + grad_clip_norm_tx = [ + optax.masked( + optax.clip_by_global_norm(config.max_grad_norm), + not_frozen_mask)] + elif 'optax_grad_pmean' in config: + if not config.optax_grad_pmean: + raise ValueError('Per-example gradient aggregateion outside of Optax ' + 'is not supported.') + + # Assume default pmean axis. + axis_name = 'batch' + if isinstance(config.optax_grad_pmean, str): + axis_name = config.optax_grad_pmean + + # Per-example clipping is implemented as differentially private gradients + # with *zero* noise. + grad_clip_norm_tx = [ + optax.masked( + optax.contrib.differentially_private_aggregate( + config.max_grad_norm, 0.0, 0), + not_frozen_mask), + aggregate_gradients_pmean(axis_name=axis_name)] + elif 'optax_grad_mean' in config: + if not config.optax_grad_mean: + raise ValueError('Per-example gradient aggregation outside of Optax ' + 'is not supported.') + grad_clip_norm_tx = [ + optax.masked( + optax.differentially_private_aggregate( + config.max_grad_norm, 0.0, 0), + not_frozen_mask),] + else: + raise ValueError( + 'When using per-example clipping, ' + 'optimizer.optax_grad_pmean or optimizer.optax_grad_mean must be set.' + ) + else: + grad_clip_norm_tx = [] + + # Optimizer updates. + tx_func = operator.attrgetter(config.optax_name)(optax) + opt_txs = [optax.masked( + tx_func(**config.get('optax_configs', {})), not_frozen_mask)] + + # Weight decay. Defaults to 0.0. + # Weight decay is not gradient-based but instead uses "params side-input". + # Hence, weight decay is additive and independent of previous gradient-based + # updates. + assert config.get('weight_decay_decouple', True), ( + 'Coupled weight decay not supported anymore.') + decay_rules = config.get('weight_decay', []) or [] + if isinstance(decay_rules, numbers.Number): + decay_rules = [('.*kernel.*', decay_rules)] + + if decay_rules: + decay_masks, mults = _make_mask_trees( + params, decay_rules, + allow_unmatched=True, log='config.optimizer.weight_decay') + mults = [mult for _, mult in mults] # Remove dummy "name" from the tuples. + + weight_decay_txs = [] + # Create decoupled WD masks by enumerating all schedule x decay mask + # combinations. + for (mult, decay_mask), (mask, base_lr) in itertools.product( + zip(mults, decay_masks), zip(masks, schedule_base_lr)): + weight_decay_txs.append( + optax.add_decayed_weights( + mult / base_lr if base_lr else 0.0, # Decouple WD from LR. + jax.tree_util.tree_map(lambda a, b: a and b, decay_mask, mask))) + else: + weight_decay_txs = [] + + # Combine gradient updates and learning rate schedules. + opt = optax.chain( + *grad_clip_norm_tx, + *opt_txs, + *weight_decay_txs, + *schedule_txs, + optax.scale(-1.0)) + return opt, schedule_fns + + +def aggregate_gradients_pmean( + axis_name: str = 'batch', +) -> optax.GradientTransformation: + """Aggregates gradients using JAX's pmean. + + Args: + axis_name: Name of the axis for pmean aggregation. + + Returns: + A `GradientTransformation`. + """ + + def init_fn(params): + del params + return None + + def update_fn(updates, state, params=None): + del params, state + return jax.lax.pmean(updates, axis_name=axis_name), None + + return optax.GradientTransformation(init_fn, update_fn) + +################# Scenic optimizers ############################## +# This is following the BV codebase pattern for defining a custom optimizer. +# A dummy object to allow for foo.bar access syntax, see +# https://stackoverflow.com/a/19476841/2366315 +optax.scenic = type('', (), {})() + + +def scale_by_adafactor(min_dim_size_to_factor=32, + decay_rate=0.8, decay_offset=0, + beta2_cap=0.999, + clipping_threshold=None, + momentum=0.9, dtype_momentum=jnp.bfloat16, + eps=1e-30): + """The BigVision variant of Adafactor optimizer.""" + + def _decay_rate_pow(i, exponent): + """Second-order moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return jnp.minimum(beta2_cap, 1.0 - t**(-exponent)) + + scale_by_rms = optax.scale_by_factored_rms( + factored=True, + decay_rate=decay_rate, + step_offset=decay_offset, + min_dim_size_to_factor=min_dim_size_to_factor, + epsilon=eps, + decay_rate_fn=_decay_rate_pow) + + clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold + else optax.identity()) + + mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum) + if momentum else optax.identity()) + + return optax.chain(scale_by_rms, clip, mom) + +optax.scenic.scale_by_adafactor = scale_by_adafactor # pytype: disable=module-attr + + +def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False): + """SGD-Momentum with half-precision accumulator.""" + return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov) + + +optax.scenic.momentum_hp = momentum_hp # pytype: disable=module-attr diff --git a/scenic/train_lib/optimizers.py b/scenic/train_lib/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..46d8ebe6bf931adc991bad15ca42085be9bbfc72 --- /dev/null +++ b/scenic/train_lib/optimizers.py @@ -0,0 +1,345 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines different optimizers with optax. + +Based on +https://github.com/google-research/big_vision/blob/main/big_vision/optax.py +and +https://github.com/google-research/big_vision/blob/main/big_vision/utils.py +""" +import copy +import dataclasses +import operator +import re +from typing import Any, Callable, Generator, List, Optional, Tuple, Union + +from absl import logging +import flax +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax + + +# JAX team is working type checking for pytrees: +# https://github.com/jax-ml/jax/issues/3340 +PyTree = Any +ScalarOrSchedule = Union[float, optax.Schedule] + + +def get_optimizer( + optimizer_config: ml_collections.ConfigDict, + learning_rate_fn: ScalarOrSchedule, + params: Optional[PyTree] = None, +) -> optax.GradientTransformation: + """Constructs the optimizer from the given configuration. + + The function is constructed in such a way that it will throw errors if + fields in the optimizer_config are misspelled. + + Args: + optimizer_config: Configuration specific to the optimizer. The config + can contain the following fields: + - optimizer: name of the optax optimizer. + - **kwargs: fields specific to the optax optimizer. + - weight_decay: value of the weight decay. + - skip_scale_and_bias_regularization: if True, do not apply weight + decay to scale and biases. + - grad_clip: configdict with settings of gradient clipping. + - freeze_params_reg_exp: regular expression to define which weights + will be frozen during training. This uses re.search, so 'conv' would + match any parameter which has 'conv' somewhere in its name such as + 'cnn/first_conv_layer/bias'. Note that only parameters will be frozen, + which means batch_norm remains unaffected. + learning_rate_fn: Learning rate schedule. + params: Parameters pytree, used when we want to skip weight decay on bias + and scale parameters. Also used for freezing weights. + + Returns: + An optax GradientTransformation, this consists of a pair of pure functions + implementing a gradient transformation. + """ + # Avoid modifying original config and allow alteration. + config = copy.deepcopy(optimizer_config).unlock() + + # Skip weight decay for BatchNorm scale or for the bias parameters. + weight_decay_mask = None + if config.get('skip_scale_and_bias_regularization') is not None: + if (config.skip_scale_and_bias_regularization and + config.get('weight_decay', 0)): + if params is None: + raise ValueError('params must be given to obtain weight_decay_mask.') + weight_decay_mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params) + if 'skip_scale_and_bias_regularization' in config: + del config.skip_scale_and_bias_regularization + + optim_ops = [] + # Add weight decay for sgd (possibly with momentum and nesterov). + if config.optimizer == 'sgd' and 'weight_decay' in config: + if config.weight_decay: + optim_ops.append( + optax.add_decayed_weights(config.weight_decay, weight_decay_mask)) + del config.weight_decay + + if weight_decay_mask and config.optimizer in {'adamw', 'lamb', 'adamaxw'}: + config.mask = weight_decay_mask + elif weight_decay_mask and config.optimizer in {'adafactor', 'lars'}: + config.weight_decay_mask = weight_decay_mask + + # Add gradient clipping before optimizer operations. + if config.get('grad_clip') is not None: + grad_clip_config = config.grad_clip + clip_method = grad_clip_config.get('clip_method', None) + clip_value = grad_clip_config.get('clip_value', None) + if clip_method is not None and clip_value is not None: + if clip_method == 'clip_by_global_norm': + optim_ops.append(optax.clip_by_global_norm(clip_value)) + elif clip_method == 'adaptive_grad_clip': + optim_ops.append(optax.adaptive_grad_clip(clip_value)) + elif clip_method == 'clip': + optim_ops.append(optax.clip(clip_value)) + elif clip_method == 'clip_by_block_rms': + optim_ops.append(optax.clip_by_block_rms(clip_value)) + else: + logging.info('%s is not supported', clip_method) + if 'grad_clip' in config: + del config.grad_clip + + # Remove freeze_params_reg_exp here. This should be the last operation to + # ensure parameters are truly frozen. But this field needs to be removed + # because all remaining fields in the config are given to the optimizer. + freeze_mask = None + unfreeze_mask = None + if config.get('freeze_params_reg_exp') is not None: + if params is None: + raise ValueError('params must be given to obtain frozen parameters.') + freeze_mask = tree_mask(params, config.freeze_params_reg_exp) + unfreeze_mask = jax.tree_util.tree_map(lambda x: not x, freeze_mask) + del config.freeze_params_reg_exp + + num_params_unfrozen = jax.tree_util.tree_reduce(operator.add, unfreeze_mask) + if not num_params_unfrozen: + raise ValueError('freeze_params_reg_exp matched all parameters in ' + 'the model, which prevents any training from happening.') + if 'freeze_params_reg_exp' in config: + del config.freeze_params_reg_exp + + # Call the optax optimizer with exact arguments as in the config. + # This throws an error when the config has (spelling) mistakes. + optimizer_fn = getattr(optax, config.optimizer) + del config.optimizer + optax_optimizer = optimizer_fn(learning_rate=learning_rate_fn, **config) + # Apply to unfrozen weights to prevent change in optimizer state. + # In turn, this prevents unnecessary gradient calculations. + if unfreeze_mask: + optax_optimizer = optax.masked(optax_optimizer, unfreeze_mask) + optim_ops.append(optax_optimizer) + + # Freezing params should be the final operation in the optax chain to ensure + # that freezing overrides everything including weight decay. + if freeze_mask: + optim_ops.append(optax.masked(optax.set_to_zero(), freeze_mask)) + + # Log variables which will change during training. + freeze_mask_flat = flax.traverse_util.flatten_dict(freeze_mask, sep='/') + logging.info('Freeze mask set. Training only on the following params:') + for param_name, value in freeze_mask_flat.items(): + if not value: + logging.info('--> %s', param_name) + + return optax.chain(*optim_ops) + + +def tree_mask(params: PyTree, reg_exp: str): + """Returns a tree mask based on regular expression for use with optax.masked. + + Args: + params: PyTree with parameters. + reg_exp: Regular expression. Will be compiled and used together with + re.search. + """ + pattern = re.compile(reg_exp) + + def match_var_name(_, name): + if pattern.search(name): + return True + return False + + return tree_map_with_names_values(match_var_name, params) + + +def get_optax_optimizer_config( + config: ml_collections.ConfigDict) -> ml_collections.ConfigDict: + """Obtain optimizer from main config.""" + optimizer_config = config.get('optimizer_configs', + ml_collections.ConfigDict()) + + # New-style config: all optimizer-related fields are in optimizer_configs. + if 'optimizer' in optimizer_config: + if 'optimizer' in config: + raise ValueError( + 'Both config.optimizer and config.optimizer_configs.optimizer are ' + 'defined. Define it only once to avoid possible contradictions. ' + 'The preferred location is in config.optimizer_configs.optimizer') + return optimizer_config + + # Backwards compatibility: copy optimizer field into the optimizer config. + optimizer_config = copy.deepcopy(optimizer_config).unlock() + if 'optimizer' in config: + optimizer_config.optimizer = config.optimizer + + # The old optimizers have adam with weight decay. However, in optax this is + # done using the adamw optimizer. + if config.optimizer == 'adam' and 'weight_decay' in optimizer_config: + optimizer_config.optimizer = 'adamw' + + if config.optimizer == 'momentum': + optimizer_config.optimizer = 'sgd' + if 'momentum' not in optimizer_config: + # flax.optim had a default momentum value of 0.9. + # optax.sgd has a default momentum of 0. + logging.warning( + 'flax.optim had a default momentum value of 0.9. optax has a ' + 'default value of 0. As a momentum value was not specified, ' + 'adding momentum=0.9 to optimizer config.') + optimizer_config.momentum = 0.9 + + if config.optimizer == 'nesterov': + optimizer_config.optimizer = 'sgd' + optimizer_config.nesterov = True + + if 'skip_scale_and_bias_regularization' in config: + optimizer_config.skip_scale_and_bias_regularization = ( + config.skip_scale_and_bias_regularization) + + optimizer_config = _scenic_optimizer_args_to_optax_args(optimizer_config) + + if 'grad_clip_configs' in config: + optimizer_config.grad_clip = config.grad_clip_configs + + optimizer_config.lock() + logging.info('Optimizer config after backwards compatibility operations:\n%s', + optimizer_config) + return optimizer_config + + +def _scenic_optimizer_args_to_optax_args( + config: ml_collections.ConfigDict) -> ml_collections.ConfigDict: + """Transform original scenic arguments to optax arguments.""" + if 'beta1' in config: + config.b1 = config.beta1 + del config.beta1 + if 'beta2' in config: + config.b2 = config.beta2 + del config.beta2 + if 'epsilon' in config: + config.eps = config.epsilon + del config.epsilon + return config + + +def _traverse_with_names( + tree: PyTree) -> Generator[Tuple[str, PyTree], None, None]: + """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" + if dataclasses.is_dataclass(tree): + tree = flax.serialization.to_state_dict(tree) + if isinstance(tree, (dict, flax.core.frozen_dict.FrozenDict)): + keys = sorted(tree.keys()) + for key in keys: + for path, v in _traverse_with_names(tree[key]): + yield (key + '/' + path).rstrip('/'), v + else: + yield '', tree + + +def tree_flatten_with_names( + tree: PyTree) -> Tuple[List[Tuple[str, jnp.ndarray]], PyTree]: + """Populates tree_flatten with leaf names. + + This function populates output of tree_flatten with leaf names, using a + custom traversal that produces names is provided. The custom traversal does + NOT have to traverse tree in the same order as jax, as we take care of + automatically aligning jax' and custom traversals. + + Args: + tree: python tree. + + Returns: + A list of values with names: [(name, value), ...] + """ + vals, tree_def = jax.tree_util.tree_flatten(tree) + + # "Fake" token tree that is use to track jax internal tree traversal and + # adjust our custom tree traversal to be compatible with it. + tokens = range(len(vals)) + token_tree = tree_def.unflatten(tokens) + val_names, perm = zip(*_traverse_with_names(token_tree)) + inv_perm = np.argsort(perm) + + # Custom traverasal should visit the same number of leaves. + assert len(val_names) == len(vals) + + return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def + + +def tree_map_with_names( + f: Callable[[jnp.ndarray], jnp.ndarray], + param_tree: PyTree, + match_name_fn: Callable[[str], bool] = lambda name: True) -> PyTree: + """Like jax.tree_util.tree_map but with a filter on the leaf path name. + + Args: + f: The function to be applied to each parameter in `param_tree`. Takes value + as argument. + param_tree: The tree of parameters `f` should be applied to. + match_name_fn: This function is called with each tree leaf's path name, + which has a path-like format ("a/b/c"), and decides whether `f` should be + applied to that leaf or the leaf should be kept as-is. + + Returns: + A tree identical in structure to `param_tree` but with the leaves the + result of calling `f` on them in the cases where `match_name_fn` returns + True for that leaf's path name. + """ + names_and_vals, tree_def = tree_flatten_with_names(param_tree) + vals = [f(v) if match_name_fn(name) else v for name, v in names_and_vals] + return tree_def.unflatten(vals) + + +def tree_map_with_names_values( + f: Callable[[jnp.ndarray, str], jnp.ndarray], + param_tree: PyTree, + match_name_fn: Callable[[str], bool] = lambda name: True) -> PyTree: + """Like tree_map_with_names but with `f` having access to values *and* names. + + Args: + f: The function to be applied to each parameter in `param_tree`. Takes value + and name as arguments. + param_tree: The tree of parameters `f` should be applied to. + match_name_fn: This function is called with each tree leaf's path name, + which has a path-like format ("a/b/c"), and decides whether `f` should be + applied to that leaf or the leaf should be kept as-is. + + Returns: + A tree identical in structure to `param_tree` but with the leaves the + result of calling `f` on them in the cases where `match_name_fn` returns + True for that leaf's path name. + """ + names_and_vals, tree_def = tree_flatten_with_names(param_tree) + vals = [ + f(v, name) if match_name_fn(name) else v for name, v in names_and_vals + ] + return tree_def.unflatten(vals) diff --git a/scenic/train_lib/pretrain_utils.py b/scenic/train_lib/pretrain_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2610ed498888a1f92713006547104849138b58e0 --- /dev/null +++ b/scenic/train_lib/pretrain_utils.py @@ -0,0 +1,353 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for using pretrained models.""" + +from collections import abc +import os +import re +from typing import Any, Dict, Mapping, List, Optional, Union + +from absl import logging +from big_vision import utils +import flax +from flax.training import checkpoints +import numpy as np + +from scenic.train_lib import train_utils +from tensorflow.io import gfile + +# JAX team is working on type annotation for pytree: +# https://github.com/jax-ml/jax/issues/1555 +PyTree = Union[Mapping[str, Mapping], Any] + + +def _replace_dict(model: PyTree, + restored: PyTree, + ckpt_prefix_path: Optional[List[str]] = None, + model_prefix_path: Optional[List[str]] = None, + name_mapping: Optional[Mapping[str, str]] = None, + skip_regex: Optional[str] = None) -> PyTree: + """Replaces values in model dictionary with restored ones from checkpoint.""" + name_mapping = name_mapping or {} + + model = flax.core.unfreeze(model) # pytype: disable=wrong-arg-types + restored = flax.core.unfreeze(restored) # pytype: disable=wrong-arg-types + + if ckpt_prefix_path: + for p in ckpt_prefix_path: + restored = restored[p] + + if model_prefix_path: + for p in reversed(model_prefix_path): + restored = {p: restored} + + # Flatten nested parameters to a dict of str -> tensor. Keys are tuples + # from the path in the nested dictionary to the specific tensor. E.g., + # {'a1': {'b1': t1, 'b2': t2}, 'a2': t3} + # -> {('a1', 'b1'): t1, ('a1', 'b2'): t2, ('a2',): t3}. + restored_flat = flax.traverse_util.flatten_dict( + dict(restored), keep_empty_nodes=True) + model_flat = flax.traverse_util.flatten_dict( + dict(model), keep_empty_nodes=True) + + for m_key, m_params in restored_flat.items(): + # pytype: disable=attribute-error + for name, to_replace in name_mapping.items(): + m_key = tuple(to_replace if k == name else k for k in m_key) + # pytype: enable=attribute-error + m_key_str = '/'.join(m_key) + if m_key not in model_flat: + logging.warning('%s in checkpoint doesn\'t exist in model. Skip.', + m_key_str) + continue + if skip_regex and re.findall(skip_regex, m_key_str): + logging.info('Skip loading parameter %s.', m_key_str) + continue + logging.info('Loading %s from checkpoint into model', m_key_str) + model_flat[m_key] = m_params + + return flax.core.freeze(flax.traverse_util.unflatten_dict(model_flat)) + + +def init_from_pretrain_state( + train_state: train_utils.TrainState, + pretrain_state: Union[PyTree, train_utils.TrainState], + ckpt_prefix_path: Optional[List[str]] = None, + model_prefix_path: Optional[List[str]] = None, + name_mapping: Optional[Mapping[str, str]] = None, + skip_regex: Optional[str] = None) -> train_utils.TrainState: + """Updates the train_state with data from pretrain_state. + + Args: + train_state: A raw TrainState for the model. + pretrain_state: A TrainState that is loaded with parameters/state of + a pretrained model. + ckpt_prefix_path: Prefix to restored model parameters. + model_prefix_path: Prefix to the parameters to replace in the subtree model. + name_mapping: Mapping from parameter names of checkpoint to this model. + skip_regex: If there is a parameter whose parent keys match the regex, + the parameter will not be replaced from pretrain_state. + + Returns: + Updated train_state. + """ + name_mapping = name_mapping or {} + restored_params = pretrain_state['params'] + restored_model_state = pretrain_state['model_state'] + model_params = _replace_dict(train_state.params, restored_params, + ckpt_prefix_path, model_prefix_path, + name_mapping, skip_regex) + train_state = train_state.replace(params=model_params) + # TODO(scenic): Add support for optionally restoring optimizer state. + if (restored_model_state is not None and + train_state.model_state is not None and train_state.model_state): + if model_prefix_path: + # Insert model prefix after 'batch_stats'. + model_prefix_path = ['batch_stats'] + model_prefix_path + if 'batch_stats' in restored_model_state: + ckpt_prefix_path = ckpt_prefix_path or [] + ckpt_prefix_path = ['batch_stats'] + ckpt_prefix_path + elif 'batch_stats' not in restored_model_state: # Backward compatibility. + model_prefix_path = ['batch_stats'] + if ckpt_prefix_path and ckpt_prefix_path[0] != 'batch_stats': + ckpt_prefix_path = ['batch_stats'] + ckpt_prefix_path + model_state = _replace_dict(train_state.model_state, + restored_model_state, + ckpt_prefix_path, + model_prefix_path, + name_mapping, + skip_regex) + train_state = train_state.replace( # pytype: disable=attribute-error + model_state=model_state) + return train_state + + +def restore_pretrained_checkpoint( + checkpoint_path: str, + train_state: Optional[train_utils.TrainState] = None, + assert_exist: bool = False, + step: Optional[int] = None) -> train_utils.TrainState: + """Restores the last checkpoint. + + First restores the checkpoint, which is an instance of TrainState that holds + the state of training. This function also take care converting pre-Linen + checkpoints. + + Args: + checkpoint_path: Directory for saving the checkpoint. + train_state: An instance of TrainState that holds the state of training. + assert_exist: Assert that there is at least one checkpoint exists in the + given path. + step: Step number to load or None to load latest. If specified, + checkpoint_path must be a directory. + + Returns: + Training state and an int which is the current step. + """ + if assert_exist: + glob_path = os.path.join(checkpoint_path, 'checkpoint_*') + if not gfile.glob(glob_path): + raise ValueError('No checkpoint for the pretrained model is found in: ' + f'{checkpoint_path}') + restored_train_state = checkpoints.restore_checkpoint(checkpoint_path, None, + step) + if restored_train_state is None: + raise ValueError('No checkpoint for the pretrained model is found in: ' + f'{checkpoint_path}') + if 'params' in restored_train_state: + # restored_train_state was trained using optax + restored_params = flax.core.freeze(restored_train_state['params']) + else: + # restored_train_state was trained using flax.optim. Note that this does + # not convert the naming of pre-Linen checkpoints. + restored_params = restored_train_state['optimizer']['target'] + if 'params' in restored_params: # Backward compatibility. + restored_params = restored_params['params'] + restored_params = dict(checkpoints.convert_pre_linen(restored_params)) + restored_params = flax.core.freeze(restored_params) + + restored_model_state = ( + None if restored_train_state['model_state'] is None else + flax.core.freeze(restored_train_state['model_state']) + ) + + if not train_state: + train_state = train_utils.TrainState() + params = restored_params + else: + # Inspect and compare the parameters of the model with the init-model. + params = inspect_params( + expected_params=train_state.params, + restored_params=restored_params, + fail_if_extra=False, + fail_if_missing=False, + fail_if_shapes_mismatch=False) + train_state = train_state.replace( + # Inspect and compare the parameters of the model with the init-model. + params=params, + model_state=restored_model_state, + global_step=int(restored_train_state['global_step']), + rng=restored_train_state['rng'], + metadata=restored_train_state.get('metadata', None)) + return train_state + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def inspect_params(*, + expected_params: PyTree, + restored_params: PyTree, + fail_if_extra: bool = True, + fail_if_missing: bool = True, + fail_if_shapes_mismatch: bool = False) -> PyTree: + """Inspects whether the params are consistent with the expected keys. + + Based on + https://github.com/google-research/big_vision/blob/main/big_vision/model/common.py. + """ + + def _flatten_params(d, parent_key='', sep='/'): + """Flattens a dictionary, keeping empty leaves.""" + items = [] + for k, v in d.items(): + path = parent_key + sep + k if parent_key else k + if isinstance(v, abc.MutableMapping): + items.extend(_flatten_params(v, path, sep=sep).items()) + else: + items.append((path, v)) + # Keeps the empty dict if it was set explicitly. + if parent_key and not d: + items.append((parent_key, {})) + return dict(items) + + expected_flat = _flatten_params(flax.core.unfreeze(expected_params)) + restored_flat = _flatten_params(flax.core.unfreeze(restored_params)) + missing_keys = expected_flat.keys() - restored_flat.keys() + extra_keys = restored_flat.keys() - expected_flat.keys() + + is_shape_mismatch = False + for key in restored_flat: + if key in expected_flat: + restored_shape = None + expected_shape = None + # Handle empty nodes (without trainable params) + if not isinstance(restored_flat[key], dict): + restored_shape = restored_flat[key].shape + if not isinstance(expected_flat[key], dict): + expected_shape = expected_flat[key].shape + + if restored_shape != expected_shape: + is_shape_mismatch = True + logging.warning('Key: %s. Expected shape: %s. Restored shape: %s', key, + expected_flat[key].shape, restored_flat[key].shape) + + # Adds back empty dict explicitly, to support layers without weights. + # Context: FLAX ignores empty dict during serialization. + empty_keys = set() + for k in missing_keys: + if isinstance(expected_flat[k], dict) and not expected_flat[k]: + restored_params[k] = {} # pytype: disable=unsupported-operands + empty_keys.add(k) + missing_keys -= empty_keys + + if empty_keys: + logging.warning('Inspect recovered empty keys:\n%s', empty_keys) + + logging.info('Inspect missing keys:\n%s', missing_keys) + logging.info('Inspect extra keys:\n%s', extra_keys) + + if fail_if_shapes_mismatch and is_shape_mismatch: + raise ValueError('Shape mismatch between restored and target model') + + if (missing_keys and fail_if_missing) or (extra_keys and fail_if_extra): + raise ValueError( + f'Missing params from checkpoint: {missing_keys}.\n' + f'Extra params in checkpoint: {extra_keys}.\n' + f'Restored params from checkpoint: {restored_flat.keys()}.\n' + f'Expected params from code: {expected_flat.keys()}.') + return restored_params +# pylint: enable=g-doc-args,g-doc-return-or-yield + + +def convert_big_vision_to_scenic_checkpoint( + checkpoint_path: str, + train_state: Optional[train_utils.TrainState] = None, + convert_to_linen: bool = True) -> train_utils.TrainState: + """Converts a big_vision checkpoint to a scenic train state. + + The model weights, global step and accumulated train time are extracted. + Optimizer state, such as the momentum, is not extracted. + + Args: + checkpoint_path: Path to big_vision checkpoint. + train_state: A Scenic TrainState object. + convert_to_linen: Whether to convert to Linen format. + + Returns: + restored_train_state: Scenic train state with model weights, global step + and accumulated training time. + """ + + def unflatten_dict(flattened: Dict[str, Any], + separator: str = '/', + leaf_idx: int = -1) -> Dict[str, Any]: + unflattened = {} + for k, v in flattened.items(): + subtree = unflattened + if leaf_idx != 0: + path = k.split(separator)[:leaf_idx] + else: + path = k.split(separator) + for k2 in path[:-1]: + if k2 not in subtree: + subtree[k2] = {} + subtree = subtree[k2] + subtree[path[-1]] = v + return unflattened + + logging.info('Loading big_vision checkpoint from %s', checkpoint_path) + if '.bv' in checkpoint_path: + checkpoint_data = utils.load_checkpoint_ts(checkpoint_path) + else: + checkpoint_data = np.load(gfile.GFile(checkpoint_path, 'rb')) + tree = unflatten_dict(checkpoint_data, separator='/', leaf_idx=0) + + restored_params = ( + tree['opt']['target'] + if 'target' in tree.get('opt', {}) + else tree['params'] + ) + if convert_to_linen: + restored_params = checkpoints.convert_pre_linen(restored_params) + restored_params = dict(restored_params) + if train_state: + restored_params = inspect_params( + expected_params=train_state.params, + restored_params=restored_params, + fail_if_extra=False, + fail_if_missing=False, + fail_if_shapes_mismatch=False) + else: + train_state = train_utils.TrainState() + + # pytype: disable=wrong-arg-types + restored_train_state = train_state.replace( # pytype: disable=attribute-error + global_step=int( + tree['opt']['state']['step'] if 'state' in tree.get('opt', {}) else 0 + ), + params=restored_params, + ) + # pytype: enable=wrong-arg-types + + return restored_train_state diff --git a/scenic/train_lib/tests/__init__.py b/scenic/train_lib/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/tests/test_classification_trainer.py b/scenic/train_lib/tests/test_classification_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/tests/test_lr_schedules.py b/scenic/train_lib/tests/test_lr_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/tests/test_optax.py b/scenic/train_lib/tests/test_optax.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/tests/test_optimizers.py b/scenic/train_lib/tests/test_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/train_utils.py b/scenic/train_lib/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2db22feb8d18b05337dc7fe73a06b3bba73ede63 --- /dev/null +++ b/scenic/train_lib/train_utils.py @@ -0,0 +1,1295 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Training.""" + +import collections.abc as collections +import copy +import functools +import os +import re +import time +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union + +from absl import logging +from clu import metric_writers +import flax +from flax import jax_utils +from flax import struct +import flax.linen as nn +from flax.training import checkpoints +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np +import optax +from scenic.common_lib import debug_utils +from scenic.dataset_lib import dataset_utils +from scenic.dataset_lib import datasets +from scenic.train_lib import optimizers +from tensorflow.io import gfile + +# JAX team is working on type annotation for pytree: +# https://github.com/jax-ml/jax/issues/1555 +PyTree = Any +PRNGKey = jnp.ndarray + + +@struct.dataclass +class TrainState: + """Dataclass to keep track of state of training. + + The state of training is structured as a struct.dataclass, which enables + instances of this class to be passed into jax transformations like tree_map + and pmap. + """ + + tx: Optional[optax.GradientTransformation] = struct.field( + default=None, pytree_node=False + ) + opt_state: Optional[optax.OptState] = None + params: Optional[Any] = struct.field(default_factory=dict) + global_step: Optional[int] = 0 + model_state: Optional[Any] = struct.field(default_factory=dict) + rng: Optional[jnp.ndarray] = None + metadata: Optional[Dict[str, Any]] = None + # NOTE: When using the raw TrainState as the target for checkpoint restoration + # in Flax, you should provide the pytree structure, otherwise it might just + # silenty ignore restoring the checkpoint subtree if you use with an empty + # dict when setting `allow_partial_mpa_restoration=True` and if you set it + # to None (e.g., for `metadata`` above), Flax replaces it with a state dict. + + def __getitem__(self, item): + """Make TrainState a subscriptable object.""" + return getattr(self, item) + + def get(self, keyname: str, default: Optional[Any] = None) -> Any: + """Return the value for key if it exists otherwise the default.""" + try: + return self[keyname] + except KeyError: + return default + + +def expand_dims_for_specs(xs, specs): + return jax.tree.map( + lambda s, x: jax.tree.map( + functools.partial(jnp.expand_dims, axis=tuple(range(len(s)))), + x, + ), + specs, + xs, + ) + + +def squeeze_for_specs(xs, specs): + return jax.tree.map( + lambda s, x: jax.tree.map( + functools.partial(jnp.squeeze, axis=tuple(range(len(s)))), + x, + ), + specs, + xs, + ) + + +def initialize_model( + *, + model_def: nn.Module, + input_spec: Sequence[ + Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...], None] + ], + config: ml_collections.ConfigDict, + rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], + train: Optional[bool] = False, + **model_kwargs, +) -> Tuple[PyTree, PyTree, int, Optional[float]]: + """Initializes parameters and model state. + + Args: + model_def: Definition of a model. + input_spec: An iterable of (shape, dtype) pairs specifying the shape and + dtype of the inputs. If unspecified the dtype is float32. + config: Configurations of the initialization. + rngs: Jax rng keys. + train: If the scenic model should be initialized in the train mode. + **model_kwargs: Kwargs passed to flax model initialization. + + Returns: + Initial params, Init model_state, number of trainable_params, and gflops. + """ + batch_size = ( + (config.batch_size // jax.device_count()) + if config.get('batch_size') + else None + ) + dummy_input = [] + for spec in input_spec: + if spec is not None: + in_st = debug_utils.input_spec_to_jax_shape_dtype_struct( + spec, batch_size=batch_size + ) + dummy_input.append(jnp.zeros(in_st.shape, in_st.dtype)) + else: + dummy_input.append(None) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @functools.partial(jax.jit, backend='cpu') + def _initialize_model(rngs): + """Initialization function to be jitted.""" + init_model_state, init_params = flax.core.pop( + flax.core.freeze( + model_def.init( + rngs, *dummy_input, train=train, debug=False, **model_kwargs + ) + ), + 'params', + ) + # Set bias in the head to low value, such that loss is small initially. + if config.get('init_head_bias', None) is not None: + init_params = flax.core.unfreeze(init_params) + init_params['output_projection'] = optimizers.tree_map_with_names( + lambda p: jnp.full_like(p, config.init_head_bias), + init_params['output_projection'], + match_name_fn=lambda name: 'bias' in name, + ) + init_params = flax.core.freeze(init_params) + return init_params, init_model_state + + if not isinstance(rngs, dict): + rngs = {'params': rngs} + init_params, init_model_state = _initialize_model(rngs) + # Pop out params rng: + rngs.pop('params') + + # Count number of trainable parameters: + num_trainable_params = debug_utils.log_param_shapes(init_params) + + # Count gflops: + count_flops = config.get( + 'count_flops', ml_collections.ConfigDict({'count_flops': True}) + ) + if count_flops: + variables = {'params': init_params, **init_model_state} + flops = debug_utils.compute_flops( + flax_model_apply_fn=functools.partial( + model_def.apply, + variables, + train=False, + debug=False, + rngs=rngs, + **model_kwargs, + ), + input_spec=count_flops.get('input_spec', input_spec), + fuse_multiply_add=count_flops.get('fuse_multiply_add', True), + ) + gflops = flops / (10**9) + else: + gflops = None + + return init_params, init_model_state, num_trainable_params, gflops + + +def initialize_model_with_pytree( + *, + model_def: nn.Module, + input_spec: PyTree, + config: ml_collections.ConfigDict, + rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], + unpack_input: bool = True, + **model_kwargs, +) -> Tuple[PyTree, PyTree, int, Optional[float]]: + """Initializes parameters and model state with a pytree input_spec. + + This is an extension of the above initialize_model function where we can put + pytree `input_spec`. We keep the original function for backward compatibility. + If the root type of `input_spec` is `Sequence`, each element is fed to the + model as position arguments whereas they are fed as keyword arguments if the + root type is `dict`. + + Args: + model_def: Definition of a model. + input_spec: A PyTree whose leaves are (shape, dtype) pairs specifying the + shape and dtype of the inputs. If unspecified the dtype is float32. + config: Configurations of the initialization. + rngs: Jax rng keys. + unpack_input: Unpack the pytree when feeding it to the model. + **model_kwargs: Kwargs passed to flax model initialization. + + Returns: + Initial params, Init model_state, number of trainable_params, and gflops. + """ + batch_size = ( + (config.batch_size // jax.device_count()) + if config.get('batch_size') + else None + ) + + def check_leaf_spec(spec: Sequence[PyTree]) -> bool: + return ( + len(spec) == 2 + and isinstance(spec[0], collections.Sequence) + and all(isinstance(i, int) for i in spec[0]) + and isinstance(spec[1], jnp.dtype) + ) or (all(isinstance(i, int) for i in spec[0])) + + def create_dummy_input(spec: PyTree) -> PyTree: + if isinstance(spec, dict): + return {k: create_dummy_input(v) for k, v in spec.items()} + elif isinstance(spec, collections.Sequence): + if check_leaf_spec(spec): + in_st = debug_utils.input_spec_to_jax_shape_dtype_struct( + spec, batch_size=batch_size + ) + return jnp.zeros(in_st.shape, in_st.dtype) + else: + return tuple(create_dummy_input(child) for child in spec) + elif spec is None: + return None + else: + raise NotImplementedError('Unsupported spec type.', type(spec)) + + dummy_input = create_dummy_input(input_spec) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @functools.partial(jax.jit, backend='cpu') + def _initialize_model(rngs): + """Initialization function to be jitted.""" + # If dummy_input is a dict, we feed inputs as keyword arguments, otherwise + # feed as position arguments. + if isinstance(dummy_input, dict) and unpack_input: + init_model_state, init_params = flax.core.pop( + flax.core.freeze( + model_def.init( + rngs, **dummy_input, train=False, debug=False, **model_kwargs + ) + ), + 'params', + ) + elif isinstance(dummy_input, collections.Sequence) and unpack_input: + init_model_state, init_params = flax.core.pop( + flax.core.freeze( + model_def.init( + rngs, *dummy_input, train=False, debug=False, **model_kwargs + ) + ), + 'params', + ) + else: + init_model_state, init_params = flax.core.pop( + flax.core.freeze( + model_def.init( + rngs, dummy_input, train=False, debug=False, **model_kwargs + ) + ), + 'params', + ) + # Set bias in the head to low value, such that loss is small initially. + if config.get('init_head_bias', None) is not None: + init_params = flax.core.unfreeze(init_params) + init_params['output_projection'] = optimizers.tree_map_with_names( + lambda p: jnp.full_like(p, config.init_head_bias), + init_params['output_projection'], + match_name_fn=lambda name: 'bias' in name, + ) + init_params = flax.core.freeze(init_params) + return init_params, init_model_state + + if not isinstance(rngs, dict): + rngs = {'params': rngs} + init_params, init_model_state = _initialize_model(rngs) + # Pop out params rng: + rngs.pop('params') + + # Count number of trainable parameters: + num_trainable_params = debug_utils.log_param_shapes(init_params) + + # Count gflops: + count_flops = config.get( + 'count_flops', ml_collections.ConfigDict({'count_flops': True}) + ) + if count_flops: + variables = {'params': init_params, **init_model_state} + flops = debug_utils.compute_flops_with_pytree( + flax_model_apply_fn=functools.partial( + model_def.apply, + variables, + train=False, + debug=False, + rngs=rngs, + **model_kwargs, + ), + input_spec=count_flops.get('input_spec', input_spec), + unpack_input=unpack_input, + fuse_multiply_add=count_flops.get('fuse_multiply_add', True), + ) + gflops = flops / (10**9) + else: + gflops = None + + return init_params, init_model_state, num_trainable_params, gflops + + +def get_dataset( + config: ml_collections.ConfigDict, + data_rng: PRNGKey, + *, + num_local_shards: Optional[int] = None, + dataset_service_address: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_configs: Optional[ml_collections.ConfigDict] = None, + **kwargs: Any, +) -> dataset_utils.Dataset: + """Creates dataset. + + By default, the values in the config file are used. + However, if the optional `dataset_name` and `dataset_configs` are passed, + those are used instead. + + Args: + config: The configuration of the experiment. + data_rng: Random number generator key to use for the dataset. + num_local_shards: Number of shards for each batch. So (bs, ...) becomes + (num_local_shards, bs//num_local_shards, ...). If not specified, it will + be number of local devices. + dataset_service_address: Used when using the tf.data.experimental.service + dataset_name: Name of dataset to load, if not reading from the config. + dataset_configs: Configuration of the dataset, if not reading directly from + the config. + **kwargs: Keyword arguments passed to the dataset builders. + + Returns: + A dataset_utils.Dataset object. + """ + device_count = jax.device_count() + logging.info('device_count: %d', device_count) + logging.info('num_hosts : %d', jax.process_count()) + logging.info('host_id : %d', jax.process_index()) + + dataset_name = dataset_name or config.dataset_name + dataset_builder = datasets.get_dataset(dataset_name) + + batch_size = config.batch_size + if batch_size % device_count > 0: + raise ValueError( + f'Batch size ({batch_size}) must be divisible by the ' + f'number of devices ({device_count})' + ) + + eval_batch_size = config.get('eval_batch_size', batch_size) + if eval_batch_size % device_count > 0: + raise ValueError( + f'Eval batch size ({eval_batch_size}) must be divisible ' + f'by the number of devices ({device_count})' + ) + + local_batch_size = batch_size // jax.process_count() + eval_local_batch_size = eval_batch_size // jax.process_count() + device_batch_size = batch_size // device_count + logging.info('local_batch_size : %d', local_batch_size) + logging.info('device_batch_size : %d', device_batch_size) + + shuffle_seed = config.get('shuffle_seed', None) + if dataset_service_address and shuffle_seed is not None: + raise ValueError( + 'Using dataset service with a random seed causes each ' + 'worker to produce exactly the same data. Add ' + 'config.shuffle_seed = None to your config if you want ' + 'to run with dataset service.' + ) + + dataset_configs = dataset_configs or config.get('dataset_configs', {}) + num_local_shards = num_local_shards or jax.local_device_count() + dataset = dataset_builder( + batch_size=local_batch_size, + eval_batch_size=eval_local_batch_size, + num_shards=num_local_shards, + dtype_str=config.data_dtype_str, + rng=data_rng, + shuffle_seed=shuffle_seed, + dataset_configs=dataset_configs, + dataset_service_address=dataset_service_address, + **kwargs, + ) + + return dataset + + +def initialize_multitask_model( + *, + model_def: nn.Module, + input_spec: Dict[ + Tuple[Tuple[str, Any], ...], + Sequence[Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...]]], + ], + config: ml_collections.ConfigDict, + rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], +) -> Tuple[PyTree, PyTree, int, Optional[Dict[str, float]]]: + """Initializes parameters and model state. + + Args: + model_def: Definition of a model. + input_spec: A dictionary from a dict of keyword arguments to an iterable of + (shape, dtype) pairs specifying the shape and dtype of the inputs. If + unspecified the dtype is float32. + config: Configurations of the initialization. + rngs: Jax rng keys. + + Returns: + Initial params, Init model_state, and number of trainable_params. + """ + + def init_fn(model_def): + for kwargs, in_spec in input_spec.items(): + if config.get('batch_sizes') is not None: + batch_size = config.batch_sizes.get(dict(kwargs)['dataset']) + else: + batch_size = config.batch_size + + batch_size = (batch_size // jax.device_count()) if batch_size else None + + input_shapetype = [ + debug_utils.input_spec_to_jax_shape_dtype_struct( + spec, batch_size=batch_size + ) + for spec in in_spec + ] + dummy_input = [] + for in_st in input_shapetype: + dummy_input.append(jnp.zeros(in_st.shape, in_st.dtype)) + model_def(*dummy_input, train=False, debug=False, **dict(kwargs)) + + # We want all parameters to be created in host RAM, not on any device, they'll + # be sent there later as needed, otherwise we already encountered two + # situations where we allocate them twice. + @functools.partial(jax.jit, backend='cpu') + def _initialize_model(rngs): + """Initialization function to be jitted.""" + init_model_state, init_params = flax.core.pop( + flax.core.freeze(nn.init(fn=init_fn, module=model_def)(rngs)), 'params' + ) + # Set bias in the head to low value, such that loss is small initially. + if ( + config.get('init_head_bias', None) is not None + and 'output_projection' in init_params + ): + init_params = flax.core.unfreeze(init_params) + init_params['output_projection'] = optimizers.tree_map_with_names( + lambda p: jnp.full_like(p, config.init_head_bias), + init_params['output_projection'], + match_name_fn=lambda name: 'bias' in name, + ) + init_params = flax.core.freeze(init_params) + return init_params, init_model_state + + if not isinstance(rngs, dict): + rngs = {'params': rngs} + init_params, init_model_state = _initialize_model(rngs) + # Pop out params rng: + rngs.pop('params') + + # Count number of trainable parameters: + num_trainable_params = debug_utils.log_param_shapes(init_params) + + # Count gflops: + count_flops = config.get('count_flops', ml_collections.ConfigDict()) + if count_flops: + variables = {'params': init_params, **init_model_state} + gflops_dict = {} + gflops_all = 0 + for kwargs, in_spec in input_spec.items(): + flops = debug_utils.compute_flops( + flax_model_apply_fn=functools.partial( + model_def.apply, + variables, + train=False, + debug=False, + rngs=rngs, + **dict(kwargs), + ), + input_spec=count_flops.get('input_spec', in_spec), + fuse_multiply_add=count_flops.get('fuse_multiply_add', True), + ) + gflops = flops / (10**9) + gflops_key = 'gflops/' + '/'.join(f'{x}={y}' for x, y in kwargs) + gflops_dict[gflops_key] = gflops + gflops_all += gflops + gflops_dict['gflops'] = gflops_all + else: + gflops_dict = None + + return init_params, init_model_state, num_trainable_params, gflops_dict + + +def get_num_training_steps( + config: ml_collections.ConfigDict, dataset_metadata: Dict[str, Any] +) -> Tuple[int, Optional[int]]: + """Calculates the total number of training step and possibly steps_per_epoch. + + The main raining loop is based on number of training steps. Thus, for datasets + that we want to train based on number of epochs, we need to calculate the + total number of training steps. This function looks for `num_training_steps` + in config, if it exists it returns that as the total step and `None` as + `steps_per_epoch`. If num_training_steps doesn't exist, then it looks for + `num_training_epochs` and given the size of training data calculates the total + steps and steps_per_epoch. In this computation, we assume that + drop_remainder=True. + + Args: + config: Configuration of the experiment. + dataset_metadata: Meta-data that is generated by the dataset_builder. + + Returns: + total_steps: Total number of training steps. + steps_per_epoch: Number of steps in every epoch. + """ + # We either use num_training_epochs or num_training_steps. + steps_per_epoch = ( + dataset_metadata.get('num_train_examples', 0) // config.batch_size + ) + + if config.get('num_training_steps') is not None: + assert not config.get('num_training_epochs') + return config.num_training_steps, steps_per_epoch or None + else: + assert config.num_training_epochs and not config.get('num_training_steps') + assert steps_per_epoch > 0, 'num_train_examples should be defined.' + return int(steps_per_epoch * config.num_training_epochs), steps_per_epoch + + +@functools.partial(jax.pmap, axis_name='x') +def pmap_mean(x: PyTree) -> PyTree: + # An axis_name is passed to pmap which can then be used by pmean. + # In this case each device has its own version of the batch statistics and + # we average them. + return jax.lax.pmean(x, 'x') + + +def sync_model_state_across_replicas(train_state: TrainState) -> TrainState: + """Sync the model_state (like batch statistics) across replicas. + + Args: + train_state: TrainState; Current state of training. + + Returns: + Updated state of training in which model_state is synced across replicas. + """ + # TODO(dehghani): We simply do "mean" here and this doesn't work with + # statistics like variance. (check the discussion in Flax for fixing this). + if jax.tree_util.tree_leaves(train_state.model_state): + # If the model_state is not empty. + new_model_state = flax.core.copy( + train_state.model_state, + {'batch_stats': pmap_mean(train_state.model_state['batch_stats'])}, + ) + return train_state.replace( # pytype: disable=attribute-error + model_state=new_model_state + ) + else: + return train_state + + +def save_checkpoint( + workdir: str, + train_state: TrainState, + max_to_keep: int = 3, + overwrite: bool = False, + **kwargs, +): + """Saves a checkpoint. + + Args: + workdir: Experiment directory for saving the checkpoint. + train_state: An instance of TrainState that holds the state of training. + max_to_keep: The number of checkpoints to keep. + overwrite: Overwrite existing checkpoint if a checkpoint at the current or + a later step already exits (default: False). + **kwargs: Passed on to flax.training.checkpoints.save_checkpoint. + """ + if jax.process_index() == 0: + # Get train state from the first replica. + checkpoint_state = jax.device_get(train_state) + checkpoints.save_checkpoint( + workdir, + checkpoint_state, + int(checkpoint_state.global_step), + overwrite=overwrite, + keep=max_to_keep, + **kwargs, + ) + + +SIGNED_FLOAT_RE = re.compile(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') + + +def checkpoint_path_step(path: str) -> Optional[float]: + """Returns the step number of a checkpoint path. + + Copied from flax/training/checkpoints.PyTree + + Args: + path: The path to the checkpoint. + + Returns: + The global step corresponding to that checkpoint, or None if it can't be + determined. + """ + for s in SIGNED_FLOAT_RE.split(path)[::-1]: + if SIGNED_FLOAT_RE.match(s): + return float(s) + return None + + +def restore_checkpoint( + checkpoint_path: str, + train_state: Optional[TrainState] = None, + assert_exist: bool = False, + step: Optional[int] = None, +) -> Tuple[TrainState, int]: + """Restores the last checkpoint. + + First restores the checkpoint, which is an instance of TrainState that holds + the state of training. + + Args: + checkpoint_path: Directory or filename to restore the checkpoint from. + train_state: An instance of TrainState that holds the state of training. + assert_exist: Assert that there is at least one checkpoint in the given + path. + step: Step number to load or None to load latest. If specified, + checkpoint_path must be a directory. + + Returns: + training state and an int which is the current step. + """ + if assert_exist: + if 'checkpoint_' in checkpoint_path.split('/')[-1]: + glob_path = checkpoint_path + else: + glob_path = os.path.join(checkpoint_path, 'checkpoint_*') + if not gfile.glob(glob_path): + raise ValueError( + 'No checkpoint for the pretrained model is found in: ' + f'{checkpoint_path}' + ) + if train_state is None: + raise ValueError( + 'Please use `restore_pretrained_checkpoint` for loading' + 'a checkpoint without providing a Scenic TrainState.' + ) + train_state = checkpoints.restore_checkpoint( + checkpoint_path, train_state, step + ) + return train_state, int(train_state.global_step) + + +def bind_rng_to_host_device( + rng: jnp.ndarray, + axis_name: Union[str, Tuple[str, ...]], + bind_to: Optional[str] = None, +) -> jnp.ndarray: + """Binds a rng to the host/device we are on. + + Must be called from within a pmapped function. Note that when binding to + "device", we also bind the rng to hosts, as we fold_in the rng with axis_index + which is unique for devices across all hosts. + + Args: + rng: A jax.random.PRNGKey. + axis_name: The axis of the devices we are binding rng across. + bind_to: Must be one of the 'host' or 'device'. None means no binding. + + Returns: + jax.random.PRNGKey specialized to host/device. + """ + if bind_to is None: + return rng + if bind_to == 'host': + return jax.random.fold_in(rng, jax.process_index()) + elif bind_to == 'device': + return jax.random.fold_in(rng, jax.lax.axis_index(axis_name)) + else: + raise ValueError( + "`bind_to` should be one of the `[None, 'host', 'device']`" + ) + + +class TrainingDivergedError(Exception): + pass + + +def normalize_metrics_summary( + metrics_summary: Dict[str, Tuple[float, int]], split: str +) -> Dict[str, float]: + """Normalize the metrics in summary by its normalizer. + + Args: + metrics_summary: A dictionary mapping metric name to (value, normalizer). + split: Split for which we normalize the metrics. Used for logging. + + Returns: + Normalized metrics summary. + + Raises: + TrainingDivergedError: Due to observing a NaN in the metrics. + """ + # TODO(dehghani): Currently we only support metrics of the form 1/N sum + # f(x_i). We may need a more general framework for metrics like + # precision and recall. Note in particular that while we're normalizing by + # the "metric normalization value" that is val[1], this value is previously + # summed up and is defined to be an integer. + normalized_metrics_summary = {} + for key, val in metrics_summary.items(): + normalized_metrics_summary[key] = val[0] / (val[1] + 1e-9) + if np.isnan(normalized_metrics_summary[key]): + msg = f'NaN detected in {split}_{key} (Unnormalized values: {val})' + if split == 'train': + raise TrainingDivergedError(msg) + else: + logging.error('WARNING: Split %s %s', split, msg) + + return normalized_metrics_summary + + +def stack_forest(forest: PyTree) -> PyTree: + """Transposes a list of dicts to dict of lists. + + For example, + given + [{'a':1,'b':2}, {'a':3,'b':4}], + the output is: + {'a': ([1, 3]), 'b': ([2, 4])} + + Args: + forest: a list of dicts + + Returns: + a dict of lists. + """ + if not forest: + return {} + + stack_args = lambda *args: np.stack(args) + return jax.tree_util.tree_map(stack_args, *forest) + + +def unreplicate_and_get(x: PyTree) -> PyTree: + return jax.device_get(jax_utils.unreplicate(x)) + + +def process_and_fetch_to_host( + pred_or_tgt: Union[jnp.ndarray, Dict[str, jnp.ndarray]], + batch_mask: jnp.ndarray, +) -> Union[Sequence[jnp.ndarray], Dict[str, jnp.ndarray]]: + """Used to collect predictions and targets of the whole valid/test set. + + Args: + pred_or_tgt: A jnp-array or dict of arrays, each of shape `[n_dev, bs, + X,...,Y]. + batch_mask: A nd-array of shape `[nun_devices, bs]`, where zero values + indicate padded examples. + + Returns: + A list of length n_dev*bs of items, where each item is a dictionary with + same keys as `pred_or_tgt` & values are normal np-arrays of shape [X,...,Y]. + """ + + def _split_mini_batches(x): + # Fetch to host and filter out padded examples. + x = jax.device_get(x)[np.array(batch_mask).astype(bool)] + # Split minibatch of examples into a list of examples. + x_list = jnp.split(x, x.shape[0], axis=0) + # Squeeze out the dummy dimension. + return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), x_list) + + pred_or_tgt = jax.tree_util.tree_map(_split_mini_batches, pred_or_tgt) + + if isinstance(pred_or_tgt, list): + # Pred_or_tgt was a single array, so just return the list: + return pred_or_tgt + else: + # Pred_or_tgt was dict of arrays, so convert dict of lists to list of dicts: + keys, values = zip(*pred_or_tgt.items()) + return [dict(zip(keys, v)) for v in zip(*values)] # pytype: disable=bad-return-type # jax-ndarray + + +@functools.partial(jax.pmap, axis_name='i') +def _barrier(x): + return jax.lax.psum(x, axis_name='i') + + +def barrier(): + """MPI-like barrier.""" + jax.device_get(_barrier(jnp.ones((jax.local_device_count(),)))) + + +def log_eval_summary( + step: int, + *, + writer: metric_writers.MetricWriter, + eval_metrics: Sequence[Dict[str, Tuple[float, int]]], + extra_eval_summary: Optional[Mapping[str, float]] = None, + metrics_normalizer_fn: Optional[ + Callable[[Dict[str, Tuple[float, int]], str], Dict[str, float]] + ] = None, + prefix: str = 'valid', + key_separator: str = '_', + flush_writer: bool = True, +) -> Dict[str, float]: + """Computes and logs eval metrics. + + Args: + step: Current step. + writer: Metric writer object. + eval_metrics: List of dictionaries of calculated metrics. Usually the + sequence is the concatenation of the per-eval-step metrics, and every + dictionary maps a metric name to an array of (value, normalizer) - where + the array index is usually the batch index. + extra_eval_summary: A dict containing summaries that are already ready to be + logged, e.g. global metrics from eval set, like precision/recall. + metrics_normalizer_fn: Used for normalizing metrics. The API for this + function is: `new_metrics_dict = metrics_normalizer_fn(metrics_dict, + split)`. If set to None, we use the `normalize_metrics_summary` which uses + the normalizer paired with each metric to normalize it (after summing both + metric and normalizer values). + prefix: str; Prefix added to the name of the summaries writen by this + function. + key_separator: Separator added between the prefix and key. + flush_writer: If True, flush the writer after logging. + + Returns: + A dictionary of metrics, mapping both `eval_metrics` and + `extra_eval_summary` from metric name (incl. `prefix`) to float value. + """ + eval_metrics = stack_forest(eval_metrics) + + # Compute the sum over all examples in all batches. + eval_metrics_summary = jax.tree_util.tree_map(lambda x: x.sum(), eval_metrics) + # Normalize metrics by the total number of examples. + metrics_normalizer_fn = metrics_normalizer_fn or normalize_metrics_summary + eval_metrics_summary = metrics_normalizer_fn(eval_metrics_summary, 'eval') + # If None, set to an empty dictionary. + extra_eval_summary = extra_eval_summary or {} + + # Adds extra_eval_summary to the returned eval_summary. + eval_metrics_summary.update(extra_eval_summary) + + writer.write_scalars( + step, + { + key_separator.join((prefix, key)): val + for key, val in eval_metrics_summary.items() + }, + ) + + if flush_writer: + writer.flush() + return eval_metrics_summary + + +def log_train_summary( + step: int, + *, + writer: metric_writers.MetricWriter, + train_metrics: Sequence[Dict[str, Tuple[float, int]]], + extra_training_logs: Optional[Sequence[Dict[str, Any]]] = None, + metrics_normalizer_fn: Optional[ + Callable[[Dict[str, Tuple[float, int]], str], Dict[str, float]] + ] = None, + prefix: str = 'train', + key_separator: str = '_', + flush_writer: bool = True, +) -> Dict[str, float]: + """Computes and logs train metrics. + + Args: + step: Current step. + writer: Summary writer. + train_metrics: List of dictionaries of calculated metrics. Usually the + sequence is the concatenation of the per-eval-step metrics, and every + dictionary maps a metric name to an array of (value, normalizer) - where + the array index is usually the batch index. + extra_training_logs: List of dictionaries, containing additional training + logs, from every train step, e.g. learning rate, Time, num parameters, + etc. Their mean will be logged. + metrics_normalizer_fn: Used for normalizing metrics. The API for this + function is: `new_metrics_dict = metrics_normalizer_fn(metrics_dict, + split)`. If set to None, we use the normalize_metrics_summary which uses + the normalizer paired with each metric to normalize it. + prefix: str; Prefix added to the name of the summaries writen by this + function. + key_separator: Separator added between the prefix and key. + flush_writer: If True, flush the writer after logging. + + Returns: + A dictionary of metrics, mapping `train_metrics from metric name (incl. + `prefix`) to float value. + """ + ##### Prepare metrics: + # Get metrics from devices: + train_metrics = stack_forest(train_metrics) + # Compute the sum over all examples in all batches: + train_metrics_summary = jax.tree_util.tree_map( + lambda x: x.sum(), train_metrics + ) + # Normalize metrics by the total number of examples: + metrics_normalizer_fn = metrics_normalizer_fn or normalize_metrics_summary + train_metrics_summary = metrics_normalizer_fn(train_metrics_summary, 'train') + + ##### Prepare additional training logs: + # If None, set to an empty dictionary. + extra_training_logs = extra_training_logs or [{}] + train_logs = stack_forest(extra_training_logs) + + # Metrics: + writer.write_scalars( + step, + { + key_separator.join((prefix, key)): val + for key, val in train_metrics_summary.items() + }, + ) + # Additional logs: + writer.write_scalars( + step, {key: val.mean() for key, val in train_logs.items()} + ) + + if flush_writer: + writer.flush() + return train_metrics_summary + + +def accumulate_gradients( + compute_gradient_fn: Callable[ + [TrainState, Dict[str, jnp.ndarray], jnp.ndarray], + Tuple[Any, jnp.ndarray], + ], + metrics_fn: Callable[ + [jnp.ndarray, Dict[str, jnp.ndarray]], Dict[str, Tuple[float, int]] + ], + train_state: TrainState, + batch: Dict[str, jnp.ndarray], + dropout_rng: jnp.ndarray, + accum_steps: Optional[int], +) -> Tuple[ + Optional[jnp.ndarray], + jnp.ndarray, + jnp.ndarray, + Dict[str, Tuple[float, int]], +]: + """Accumulate gradients over multiple steps. + + This enables training with larger effective batch sizes. + Note that currently, gradient accumulation is not supported when the + `model_state` is used, e.g., for models that have batch normalization and + store batch statistics in the `model_state`. + + Note that if `accum_steps` <= 1 or is None, then the gradient of a single step + is simply returned. + + Args: + compute_gradient_fn: Gradient function, e.g., `jax.value_and_grad( + training_loss_fn, ...). + metrics_fn: A metrics function that given logits and batch of data, + calculates the metrics. + train_state: An instance of TrainState that has the parameters of the model, + state of the model, etc. + batch: A single batch of data. The buffer of this argument can be donated to + the computation. + dropout_rng: JAX rng key used for dropout. + accum_steps: Number of accumulating steps (number of micro batches). When + set to None or =<1, no accumulation is done. + + Returns: + A tuple of model_state (e.g., batch statistics), + computed gradients, training loss, and calculated metrics. + """ + params = train_state.params + if accum_steps and accum_steps > 1: + batch_size = next(iter(batch.values())).shape[0] + microbatch_size = batch_size // accum_steps + if batch_size % accum_steps != 0: + raise ValueError( + f'Bad accum_steps {accum_steps} for batch size {batch_size}' + ) + logging.info( + 'Using microbatches: %d microbatches, %d size', + accum_steps, + microbatch_size, + ) + + def get_microbatch( + batch: Dict[str, jnp.ndarray], idx: int + ) -> Dict[str, jnp.ndarray]: + """Fetch microbatch slice from the given batch.""" + return jax.tree_util.tree_map( + lambda x: x.reshape((-1, microbatch_size) + x.shape[1:])[idx], batch + ) + + def per_microbatch_compute_gradient_fn( + loop_cnt: int, + loop_state: Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, Tuple[float, int]] + ], + ) -> Tuple[ + jnp.ndarray, jnp.ndarray, Dict[str, Tuple[float, int]], jnp.ndarray + ]: + dropout_rng, grad_accum, train_loss_acc, metrics_acc = loop_state + dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) + mbatch = get_microbatch(batch, loop_cnt) + (train_loss, (_, mlogits)), grad = compute_gradient_fn( + params, mbatch, sub_dropout_rng + ) + metrics = metrics_fn(mlogits, mbatch) + # Accumulate gradients and metrics. + grad = jax.tree_util.tree_map(jnp.add, grad_accum, grad) + metrics = jax.tree_util.tree_map(jnp.add, metrics, metrics_acc) + train_loss = jax.tree_util.tree_map(jnp.add, train_loss, train_loss_acc) + return dropout_rng, grad, train_loss, metrics + + # Initialize gradient accumulation loop state. + dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) + init_mbatch = get_microbatch(batch, 0) + (init_train_loss, (model_state, init_logits)), grad_init = ( + compute_gradient_fn(params, init_mbatch, sub_dropout_rng) + ) + if jax.tree_util.tree_leaves(model_state): + # If the model_state is not empty. + raise ValueError( + 'Gradient accumulation is not supported when the ' + 'model_state is in used (e.g. models w/ batch norm).' + ) + + metrics_init = metrics_fn(init_logits, init_mbatch) + del init_logits, init_mbatch + + # Run gradient accumulation loop. + loop_init = (dropout_rng, grad_init, init_train_loss, metrics_init) + _, grad_acc, train_loss, metrics_acc = jax.lax.fori_loop( + 1, accum_steps, per_microbatch_compute_gradient_fn, loop_init + ) + grad_acc = jax.tree_util.tree_map(lambda x: x / accum_steps, grad_acc) + train_loss = jax.tree_util.tree_map(lambda x: x / accum_steps, train_loss) + return model_state, grad_acc, train_loss, metrics_acc + else: + (train_loss, (model_state, logits)), grad = compute_gradient_fn( + params, batch, dropout_rng + ) + metrics = metrics_fn(logits, batch) + return model_state, grad, train_loss, metrics + + +class Chrono: + """Measures time and reports progress. + + This is a modified fork of Chrono class from big_vision codebase: + https://github.com/google-research/big_vision/blob/main/big_vision/utils.py + + Some concepts: + 1. This differentiates between three "types" of time: + - training time: the time spent on actual training (fprop/bprop/update) + - program time: overall time the program runs, including all overheads + - pause time: the chronometer can be paused (eg during evals). + 2. This handles a "warmup": the first step is skipped for training time + purposes, as it includes significant compilation overheads, which distort + estimates. + 3. `accumulates` (i.e. integrates) timings, and saves/loads them across + restarts. + """ + + def __init__(self, example_type: str = 'img', warmup: int = 2): + self.program_start_time = time.monotonic() + self.train_start_time = None + self.train_start_step = None # When we started timing (after warmup) + + self.prev_time = None + self.prev_step = None + + self.pause_start = None + self.paused_time = 0 + + self.warmup = warmup # How many calls to `tick` to skip. + self.load() # Inits accum integrators. + self.note = 'Chrono n/a' + self.example_type = example_type + + def inform( + self, + first_step: int, + total_steps: int, + global_bs: int, + steps_per_epoch: int, + ): + """Provide some extra info that's only known later in the program.""" + self.prev_step = copy.deepcopy(first_step) + self.first_step = copy.deepcopy(first_step) + self.total_steps = total_steps + self.steps_per_epoch = steps_per_epoch + self.global_bs = global_bs + if total_steps: + self.note = ( + f'Steps:{first_step}/{total_steps} [{first_step/total_steps:.1%}]' + ) + + def tick( + self, + step: int, + writer: metric_writers.MetricWriter, + write_note: Callable[[str], None], + ): + """A chronometer tick.""" + summary = {} + + def hms(s): + """Format time in hours/minutes/seconds.""" + if s < 60: + return f'{s:.0f}s' + m, s = divmod(s, 60) + if m < 60: + return f'{m:.0f}m{s:.0f}s' + h, m = divmod(m, 60) + return f'{h:.0f}h{m:.0f}m' # Seconds intentionally omitted. + + now = time.monotonic() + summary.update({'uptime': now - self.program_start_time}) + # We always count examples, regardless of the timing-related warmup that + # happens a few lines below. + ds = step - self.prev_step # Steps between ticks + self.prev_step = step + self.accum_examples_seen += ds * self.global_bs + summary.update({'examples_seen': self.accum_examples_seen}) + if self.steps_per_epoch: + summary.update({'epoch': step / self.steps_per_epoch}) + + # We take the start as the second time `tick` is called, so we avoid + # measuring the overhead of compilation and don't include it in time + # estimates. + if self.warmup > 1: + self.warmup -= 1 + write_note(self.note) # This can help debugging. + return + if self.warmup == 1: + self.train_start_time = self.prev_time = now + self.train_start_step = step + self.accum_program_time += now - self.program_start_time + self.paused_time = 0 # Drop pauses that happened before timing starts. + self.warmup = 0 + write_note(self.note) # This can help debugging. + return + + # Measurement with micro-timings of current training steps speed. + # Time between ticks (ignoring pause) + if self.prev_time is None: + raise ValueError('prev_time is None, possible warmup was skipped') + dt = now - self.prev_time - self.paused_time + ncores = jax.device_count() # Global device count + summary.update({ + f'{self.example_type}/sec/core': self.global_bs * ds / dt / ncores, + f'{self.example_type}/sec': self.global_bs * ds / dt, + }) + + # Accumulate (integrate) times, good for plots. + self.accum_train_time += dt + self.accum_pause_time += self.paused_time + self.accum_program_time += dt + self.paused_time + + # Convert to, and log as, core hours. + core_hours = self.accum_train_time * ncores / 60 / 60 + devtype = jax.devices()[0].device_kind + summary.update({ + f'core_hours_{devtype}': core_hours, + 'core_hours': core_hours, # For convenience as x-axis in sweeps. + }) + + # Progress note with "global" full-program average timings + # (eg in program-time minus warmup) + dt = now - self.train_start_time # Time elapsed since end of warmup. + steps_timed = step - self.train_start_step + steps_todo = self.total_steps - step + self.note = f'Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]' + self.note += f'\nWalltime:{hms(self.accum_program_time)}' + self.note += f' ({hms(self.accum_pause_time)} Not-train)' + self.note += f'\nETA:{hms(dt / steps_timed * steps_todo)}' + self.note += ( + f'\nTotal train time:{hms(dt / steps_timed * self.total_steps)}' + ) + write_note(self.note) + writer.write_scalars(step, summary) + self.prev_time = now + self.paused_time = 0 + + def pause(self, wait_for=()): + assert self.pause_start is None, "Don't pause twice." + jax.block_until_ready(wait_for) + self.pause_start = time.monotonic() + + def resume(self): + assert self.pause_start is not None, 'Cannot resume without pausing first.' + self.paused_time += time.monotonic() - self.pause_start + self.pause_start = None + + def save(self): + return dict( + accum_program_time=self.accum_program_time, + accum_train_time=self.accum_train_time, + accum_pause_time=self.accum_pause_time, + accum_examples_seen=self.accum_examples_seen, + ) + + def load(self, ckpt={}): # pylint: disable=dangerous-default-value + self.accum_program_time = ckpt.get('accum_program_time', 0.0) + self.accum_train_time = ckpt.get('accum_train_time', 0.0) + self.accum_pause_time = ckpt.get('accum_pause_time', 0.0) + self.accum_examples_seen = ckpt.get('accum_examples_seen', 0) + + +def barrier_across_hosts(): + """Ensure all hosts stay up until the end, otherwise the program may hang.""" + if jax.process_count() > 1: + x = jnp.ones([jax.local_device_count()]) + x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)) + assert x[0] == jax.device_count() + + +def handle_checkpointing( + train_state: TrainState, + chrono: Chrono, + workdir: str, + max_checkpoints_to_keep=3, +): + """Handles all the bookkeeping around checkpointing. + + Syncs the training state and unreplicates it, stops & restarts Chrono + (and handles its metadata) and writes the actual checkpoint. + + Args: + train_state: A replicated TrainState. + chrono: The Chrono object. + workdir: the workdir of the process. + max_checkpoints_to_keep: how many checkpoints to keep. + """ + train_state = sync_model_state_across_replicas(train_state) + if jax.process_index() == 0: + unrep_train_state = jax_utils.unreplicate(train_state) + metadata = unrep_train_state.metadata + metadata['chrono'] = chrono.save() + unrep_train_state = unrep_train_state.replace(metadata=metadata) + save_checkpoint( + workdir, unrep_train_state, max_to_keep=max_checkpoints_to_keep + ) + del unrep_train_state diff --git a/scenic/train_lib/trainers.py b/scenic/train_lib/trainers.py new file mode 100644 index 0000000000000000000000000000000000000000..4369a868b3b4895e6b3c897bcd02ed562d8ade7a --- /dev/null +++ b/scenic/train_lib/trainers.py @@ -0,0 +1,48 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry for the available trainers.""" + +from scenic.train_lib import classification_trainer +from scenic.train_lib.transfer import transfer_trainer + +ALL_TRAINERS = { + 'classification_trainer': classification_trainer.train, + 'transfer_trainer': transfer_trainer.train, +} + + +def get_trainer(train_fn_name): + """Get the corresponding trainer function. + + The returned train function has the following API: + ``` + train_state, train_summary, eval_summary = train_fn( + rng, model_cls, dataset, config, workdir, summary_writer) + ``` + Where the train_state is a checkpointable state of training and train_summary, + and eval_summary are python dictionary that contains metrics. + + Args: + train_fn_name: str; Name of the train_fn_name, e.g. + 'classification_trainer'. + + Returns: + The train function. + Raises: + ValueError if train_fn_name is unrecognized. + """ + if train_fn_name not in ALL_TRAINERS.keys(): + raise ValueError('Unrecognized trainer: {}'.format(train_fn_name)) + return ALL_TRAINERS[train_fn_name] diff --git a/scenic/train_lib/transfer/__init__.py b/scenic/train_lib/transfer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/transfer/fewshot_utils.py b/scenic/train_lib/transfer/fewshot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/transfer/linear_probe_utils.py b/scenic/train_lib/transfer/linear_probe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenic/train_lib/transfer/transfer_trainer.py b/scenic/train_lib/transfer/transfer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1c0ee27fd22f373b755bbf142dd96358c5f1d6 --- /dev/null +++ b/setup.py @@ -0,0 +1,120 @@ +# Copyright 2024 The Scenic Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""setup.py for Scenic. + +Install for development: + + pip intall -e . .[testing] +""" + +import os +import urllib.request + +from setuptools import Command +from setuptools import find_packages +from setuptools import setup +from setuptools.command import install + +SIMCLR_DIR = "simclr/tf2" +DATA_UTILS_URL = "https://raw.githubusercontent.com/google-research/simclr/master/tf2/data_util.py" + + +class DownloadSimCLRAugmentationCommand(Command): + """Downloads SimCLR data_utils.py as it's not built into an egg.""" + description = __doc__ + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + build_cmd = self.get_finalized_command("build") + dist_root = os.path.realpath(build_cmd.build_lib) + output_dir = os.path.join(dist_root, SIMCLR_DIR) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_path = os.path.join(output_dir, "data_util.py") + downloader = urllib.request.URLopener() + downloader.retrieve(DATA_UTILS_URL, output_path) + + +class InstallCommand(install.install): + + def run(self): + self.run_command("simclr_download") + install.install.run(self) + + +install_requires_projects = [ + "ott-jax>=0.2.0", + "sklearn", + "lingvo==0.12.6", + "seaborn>=0.11.2", + "dmvr @ git+https://github.com/google-deepmind/dmvr.git", +] + +install_requires_core = [ + "absl-py>=1.0.0", + "numpy>=1.12", + "jax>=0.4.3", + "jaxlib>=0.4.3", + "flax>=0.4.0", + "ml-collections>=0.1.1", + "tensorflow>=2.7", + "immutabledict>=2.2.1", + "clu>=0.0.6", + "tensorflow-datasets", + "optax @ git+https://github.com/google-deepmind/optax.git@main", +] + +tests_require = [ + "pytest", + "shapely", +] + install_requires_projects + +setup( + name="scenic", + version="0.0.1", + description=("A Jax Library for Computer Vision Research and Beyond."), + author="Scenic Authors", + author_email="no-reply@google.com", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="http://github.com/google-research/scenic", + license="Apache 2.0", + packages=find_packages(), + include_package_data=True, + install_requires=install_requires_core, + cmdclass={ + "simclr_download": DownloadSimCLRAugmentationCommand, + "install": InstallCommand, + }, + tests_require=tests_require, + extras_require={ + "testing": tests_require, + }, + classifiers=[ + "Development Status :: 1 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + keywords="Scenic", +)