# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from examples.cifar_ViT import Classifier, VisionTransformer

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# This example is very close to the cifar_ViT one and reuses most of its training code;
# only the model definition differs. There are many ways to implement a MetaFormer with
# xformers, for instance by picking the parts from `xformers.components` and writing the
# model explicitly, or by patching another existing ViT-like implementation.
# This example takes a different approach: the whole model is described in one go (a dict
# structure), which the xformers factory then turns into a model. This hides a lot of the
# model building (although the resulting implementation can be inspected), but makes it
# trivial to do some hyperparameter search.
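# For orientation, each entry the factory consumes is a plain dict, roughly of the form
# sketched below. This is an abridged, assumed shape; the exact field names live in
# xformers.helpers.hierarchical_configs and may differ across xformers versions:
#
#   {
#       "block_type": "encoder",
#       "dim_model": 64,  # per-stage embedding width
#       "multi_head_config": {"attention": {"name": "scaled_dot_product", ...}},
#       "feedforward_config": {"name": "MLP", ...},
#       "patch_embedding_config": {"kernel_size": 3, "stride": 2, "padding": 1},
#   }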


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-3,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        attention="scaled_dot_product",
        feedforward="MLP",
        residual_norm_style="pre",
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):
        # Deliberately skip VisionTransformer.__init__ (which would build the ViT trunk)
        # and call the LightningModule constructor directly; the training logic is
        # inherited, the model itself is rebuilt below.
        super(VisionTransformer, self).__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical Transformer
        # - This is a small poolformer configuration, adapted to the small CIFAR10
        #   pictures (32x32)
        # - Note that this does not match the L1 configuration from the paper, which
        #   would use repeated layers. CIFAR pictures are too small for that config to
        #   be directly meaningful (although it would run)
        # - Any related config would work, and the attention mechanisms do not have to
        #   be the same across layers
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=512,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
                feedforward=feedforward,
                repeat_layer=1,
            ),
        ]
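
        # Sanity check on the sequence lengths above: every stage embeds patches with a
        # stride-2 convolution, so the spatial resolution halves per stage and the token
        # count divides by 4. For 32x32 inputs that is 256, 64, 16 and finally 4 tokens.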

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            residual_norm_style=residual_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )
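
        # A note on the two extra parameters (an assumed reading of the helper; check
        # get_hierarchical_configuration if in doubt): dim_head=32 derives the number
        # of heads from each stage's width (64 // 32 = 2 heads up to 512 // 32 = 16),
        # and mlp_multiplier=4 sets the usual 4x feedforward expansion.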

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        # "moco" selects the MoCo-style weight init shipped with the factory
        config.weight_init = "moco"
        print(config)

        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        # NOTE: torchmetrics >= 0.11 requires an explicit task, i.e.
        # Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = self.trunk(x)
        x = self.ln(x)

        x = x.mean(dim=1)  # mean over the sequence length
        x = self.head(x)

        return x
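

# A quick shape sanity check (hypothetical usage sketch, assuming NCHW float input):
#
#   model = MetaVisionTransformer(steps=100)
#   logits = model(torch.randn(2, 3, 32, 32))
#   assert logits.shape == (2, 10)
#
# The trunk returns a (batch, seq_len, dim) tensor; the mean over the sequence
# dimension in forward() is what implements the Classifier.GAP head.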


if __name__ == "__main__":
    # seed_everything seeds python, numpy, torch and the CUDA RNGs in one call
    pl.seed_everything(42)

    # Adjust the batch size depending on the memory available on your machine.
    # Reversible layers can also be used to save memory
    REF_BATCH = 768
    BATCH = 256  # lower if not enough GPU memory
    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1
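
    # To trade compute for memory instead of lowering BATCH, xformers can also build
    # reversible blocks (a "reversible" flag in the factory block configs); whether
    # the hierarchical helper exposes it depends on the xformers version.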

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    #   for a full tutorial
    # - Please note that the default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute the total number of optimizer steps, used by the LR schedule
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
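
    # Rough arithmetic, assuming the datamodule's default 5,000-sample validation
    # split: 45_000 // 768 * 50 = 2_900 optimizer steps over the full run.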
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        feedforward="MLP",
        use_rotary_embeddings=True,
    )
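
    # The optimizer hyperparameters (learning_rate, betas, weight_decay,
    # linear_warmup_ratio) keep their defaults here and are presumably consumed by
    # the training logic inherited from the cifar_ViT VisionTransformer.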
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # evaluate the trained model on the held-out test set
    trainer.test(lm, datamodule=dm)