# scenic/projects/token_learner/configs/im1k_token_learner_config.py
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
r"""Default configs for ViT on ImageNet2012."""
# pylint: disable=line-too-long
import ml_collections
# Number of examples in the ImageNet-1k (ILSVRC-2012) training split; used
# below to derive steps_per_epoch from the batch size.
_IMAGENET_TRAIN_SIZE = 1281167
# ViT variant string '<size>/<patch>', e.g. 'B/16' = ViT-Base with 16x16
# patches; parsed in get_config() to select model dimensions.
VARIANT = 'B/16'
def get_config(runlocal=''):
  """Returns the ViT + TokenLearner experiment configuration for ImageNet.

  Args:
    runlocal: Truthy (e.g. a non-empty string) to configure a small local
      run with a tiny batch size; empty/falsy for the full training setup.

  Returns:
    An `ml_collections.ConfigDict` holding the complete experiment config.
  """
  runlocal = bool(runlocal)

  config = ml_collections.ConfigDict()
  config.experiment_name = 'imagenet-vit'

  # Dataset.
  config.dataset_name = 'imagenet'
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()

  # Model.
  version, patch = VARIANT.split('/')
  config.model_name = 'token_learner_multilabel_classification'
  config.model = ml_collections.ConfigDict()
  config.model.hidden_size = {
      'Ti': 192, 'S': 384, 'B': 768, 'L': 1024, 'H': 1280}[version]
  config.model.tokenizer = ml_collections.ConfigDict()
  # Set this to 'dynamic' to use TokenLearner.
  config.model.tokenizer.type = 'dynamic'
  config.model.tokenizer.patches = ml_collections.ConfigDict()
  config.model.tokenizer.patches.size = [int(patch), int(patch)]
  config.model.tokenizer.num_tokens = 16  # Number of tokens to learn.
  # The layer to insert TokenLearner at. Must be in
  # [0, config.model.num_layers). Change this to control the
  # accuracy/computation trade-off.
  config.model.tokenizer.tokenlearner_loc = 9
  # Whether to use TokenFuser as well.
  config.model.tokenizer.use_tokenfuse = False
  # Whether to use TokenLearner V1.1. If False, uses the original
  # TokenLearner module.
  config.model.tokenizer.use_v11 = True
  config.model.num_heads = {'Ti': 3, 'S': 6, 'B': 12, 'L': 16, 'H': 16}[version]
  config.model.mlp_dim = {
      'Ti': 768, 'S': 1536, 'B': 3072, 'L': 4096, 'H': 5120}[version]
  config.model.num_layers = {
      'Ti': 12, 'S': 12, 'B': 12, 'L': 24, 'H': 32}[version]
  config.model.representation_size = None
  config.model.classifier = 'gap'
  config.model.attention_dropout_rate = 0.0
  config.model.dropout_rate = 0.0
  config.model.stochastic_depth = 0.1
  config.model_dtype_str = 'float32'

  # Training.
  config.trainer_name = 'classification_trainer'
  config.optimizer = 'adam'
  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.beta1 = 0.9
  config.optimizer_configs.beta2 = 0.999
  config.optimizer_configs.weight_decay = 0.3
  config.explicit_weight_decay = None  # No explicit weight decay.
  config.l2_decay_factor = None
  config.max_grad_norm = 1.0
  config.label_smoothing = None
  config.num_training_epochs = 90
  config.log_eval_steps = 1000
  config.batch_size = 8 if runlocal else 4096
  config.rng_seed = 42
  config.init_head_bias = -10.0

  # Learning rate.
  steps_per_epoch = _IMAGENET_TRAIN_SIZE // config.batch_size
  total_steps = config.num_training_epochs * steps_per_epoch
  base_lr = 5e-3
  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = 'compound'
  config.lr_configs.factors = 'constant*linear_warmup*linear_decay'
  config.lr_configs.total_steps = total_steps
  config.lr_configs.end_learning_rate = 1e-5
  config.lr_configs.warmup_steps = 10_000
  config.lr_configs.base_learning_rate = base_lr

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = 5000
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  return config