# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS:
# inspired by
# https://github.com/nateraw/lightning-vision-transformer
# which in turn references https://github.com/lucidrains/vit-pytorch

# Original author: Sean Naren
import math
from enum import Enum

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from xformers.factory import xFormer, xFormerConfig
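# Note (assumption about the environment, not part of the original example): this
# script targets pytorch-lightning 1.x (it uses the `gpus=` Trainer argument,
# removed in 2.0) and expects `pl_bolts`, `torchmetrics`, and `xformers` installed.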

class Classifier(str, Enum):
    GAP = "gap"
    TOKEN = "token"
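# Pooling strategies for the classification head: GAP averages the final token
# embeddings over the sequence, while TOKEN prepends a learnable class token
# (see `add_class_token` in the config below) and classifies from that token alone.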

class VisionTransformer(pl.LightningModule):
    def __init__(
        self,
        steps,
        learning_rate=5e-4,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        patch_size=2,
        dim=384,
        n_layer=6,
        n_head=6,
        resid_pdrop=0.0,
        attn_pdrop=0.0,
        mlp_pdrop=0.0,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        hidden_layer_multiplier=4,
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier: Classifier = Classifier.TOKEN,
    ):
        super().__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        assert image_size % patch_size == 0
        num_patches = (image_size // patch_size) ** 2
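        # With the defaults above: (32 // 2) ** 2 = 16 * 16 = 256 patches, i.e. a
        # sequence of 256 tokens (plus one class token when classifier == TOKEN).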
        # A list of the encoder or decoder blocks which constitute the Transformer.
        xformer_config = [
            {
                "block_type": "encoder",
                "num_layers": n_layer,
                "dim_model": dim,
                "residual_norm_style": residual_norm_style,
                "multi_head_config": {
                    "num_heads": n_head,
                    "residual_dropout": resid_pdrop,
                    "use_rotary_embeddings": use_rotary_embeddings,
                    "attention": {
                        "name": attention,
                        "dropout": attn_pdrop,
                        "causal": False,
                    },
                },
                "feedforward_config": {
                    "name": "MLP",
                    "dropout": mlp_pdrop,
                    "activation": "gelu",
                    "hidden_layer_multiplier": hidden_layer_multiplier,
                },
                "position_encoding_config": {
                    "name": "learnable",
                    "seq_len": num_patches,
                    "dim_model": dim,
                    "add_class_token": classifier == Classifier.TOKEN,
                },
                "patch_embedding_config": {
                    "in_channels": 3,
                    "out_channels": dim,
                    "kernel_size": patch_size,
                    "stride": patch_size,
                },
            }
        ]
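        # Note: the factory expands this single entry into `n_layer` identical
        # encoder blocks (the "num_layers" field above).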
        # The ViT trunk
        config = xFormerConfig(xformer_config)
        self.vit = xFormer.from_config(config)
        print(self.vit)

        # The classifier head
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()
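        # Note (assumption about the installed version): torchmetrics >= 0.11
        # requires an explicit task, e.g.
        # Accuracy(task="multiclass", num_classes=num_classes).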
    @staticmethod
    def linear_warmup_cosine_decay(warmup_steps, total_steps):
        """
        Linear warmup for warmup_steps, with cosine annealing to 0 at total_steps
        """

        def fn(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))

            progress = float(step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return fn
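    # Example of the schedule (warmup_steps=100, total_steps=1000):
    #   fn(0)    -> 0.0   (start of the linear warmup)
    #   fn(50)   -> 0.5
    #   fn(100)  -> 1.0   (warmup done, cosine decay begins)
    #   fn(550)  -> 0.5   (halfway through the decay)
    #   fn(1000) -> 0.0
    # LambdaLR multiplies the base learning rate by this factor at every step.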
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.betas,
            weight_decay=self.hparams.weight_decay,
        )

        warmup_steps = int(self.hparams.linear_warmup_ratio * self.hparams.steps)

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                self.linear_warmup_cosine_decay(warmup_steps, self.hparams.steps),
            ),
            "interval": "step",
        }

        return [optimizer], [scheduler]
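    # "interval": "step" makes Lightning call scheduler.step() after every
    # optimizer step (rather than once per epoch), so the warmup/decay schedule
    # above is expressed in optimizer steps.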
    def forward(self, x):
        x = self.vit(x)
        x = self.ln(x)

        if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # only consider the token, we're classifying anyway
        elif self.hparams.classifier == Classifier.GAP:
            x = x.mean(dim=1)  # mean over sequence len

        x = self.head(x)
        return x
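    # Expected shapes (with the defaults): input (B, 3, 32, 32) -> trunk output
    # (B, seq_len, dim) -> pooled (B, dim) -> logits (B, num_classes).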
    def training_step(self, batch, _):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        self.logger.log_metrics(
            {
                "train_loss": loss.mean(),
                "learning_rate": self.lr_schedulers().get_last_lr()[0],
            },
            step=self.global_step,
        )

        return loss
    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = self.val_accuracy(y_hat, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, _):
        self.evaluate(batch, "val")

    def test_step(self, batch, _):
        self.evaluate(batch, "test")

if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 512
    BATCH = 128

    MAX_EPOCHS = 30
    NUM_WORKERS = 4
    GPUS = 1

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    #   for a full tutorial
    # - Please note that default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute total number of steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
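    # Rough arithmetic (assuming the pl_bolts default val_split of 5_000, i.e.
    # num_samples == 45_000): 45_000 // 512 * 30 ≈ 2_610 optimizer steps, since
    # gradient accumulation (REF_BATCH // BATCH = 4) makes 512 the effective batch.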
    lm = VisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        classifier=Classifier.TOKEN,
        residual_norm_style="pre",
        use_rotary_embeddings=True,
    )
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        detect_anomaly=False,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )

    trainer.fit(lm, dm)

    # check the training
    trainer.test(lm, datamodule=dm)
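    # Minimal inference sketch (not part of the original example; assumes the run
    # above completed): score one test batch on CPU and print a few predictions.
    lm = lm.cpu().eval()
    with torch.no_grad():
        x, y = next(iter(dm.test_dataloader()))
        preds = lm(x).argmax(dim=-1)
        print("predicted:", preds[:8].tolist(), "target:", y[:8].tolist())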