"""
Parts of this file have been adapted from
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
"""
import pytorch_lightning as pl
import torch.nn.functional as F
from argparse import ArgumentParser
from torch import Tensor
from torch.optim import AdamW, Optimizer, RAdam
from torch.optim.lr_scheduler import _LRScheduler
from transformers import get_scheduler, PreTrainedModel
class ImageClassificationNet(pl.LightningModule):
    """A PyTorch Lightning module wrapping a HuggingFace image-classification model.

    Handles optimizer/scheduler setup (AdamW or RAdam with a linear LR decay
    over the full training run) and logs cross-entropy loss and accuracy for
    the train/val/test loops.
    """

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Add this model's hyperparameter CLI flags to ``parent_parser``.

        Args:
            parent_parser (ArgumentParser): parser to extend in place.

        Returns:
            ArgumentParser: the same ``parent_parser``, for chaining.
        """
        parser = parent_parser.add_argument_group("Classification Model")
        parser.add_argument(
            "--optimizer",
            type=str,
            default="AdamW",
            choices=["AdamW", "RAdam"],
            help="The optimizer to use to train the model.",
        )
        parser.add_argument(
            "--weight_decay",
            type=float,
            default=1e-2,
            help="The optimizer's weight decay.",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=5e-5,
            help="The initial learning rate for the model.",
        )
        return parent_parser

    def __init__(
        self,
        model: PreTrainedModel,
        num_train_steps: int,
        optimizer: str = "AdamW",
        weight_decay: float = 1e-2,
        lr: float = 5e-5,
    ):
        """A PyTorch Lightning Module for a HuggingFace model used for image classification.

        Args:
            model (PreTrainedModel): a pretrained model for image classification
            num_train_steps (int): number of training steps
            optimizer (str): optimizer to use ("AdamW" or "RAdam")
            weight_decay (float): weight decay for optimizer
            lr (float): the learning rate used for training
        """
        super().__init__()
        # Save the hyperparameters (excluding the model weights) and the model.
        self.save_hyperparameters(ignore=["model"])
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        """Return the raw classification logits for a batch of images."""
        return self.model(x).logits

    def configure_optimizers(self) -> tuple[list[Optimizer], list[dict]]:
        """Create the optimizer and a per-step linear LR schedule.

        Returns:
            A ([optimizer], [lr_scheduler_config]) pair as accepted by Lightning.

        Raises:
            ValueError: if ``self.hparams.optimizer`` is not a known choice.
        """
        # Set the optimizer class based on the hyperparameter.
        if self.hparams.optimizer == "AdamW":
            optim_class = AdamW
        elif self.hparams.optimizer == "RAdam":
            optim_class = RAdam
        else:
            raise ValueError(f"Unknown optimizer {self.hparams.optimizer}")

        # Create the optimizer and the learning rate scheduler.
        optimizer = optim_class(
            self.parameters(),
            weight_decay=self.hparams.weight_decay,
            lr=self.hparams.lr,
        )
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.hparams.num_train_steps,
        )
        # The HF linear schedule is defined over optimization *steps*
        # (num_training_steps), but Lightning steps a bare scheduler once per
        # *epoch* by default — return a scheduler config so it is stepped
        # every batch instead.
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

    def _calculate_loss(self, batch: tuple[Tensor, Tensor], mode: str) -> Tensor:
        """Compute and log cross-entropy loss and accuracy for one batch.

        Args:
            batch: an (images, labels) pair.
            mode: logging prefix, one of "train"/"val"/"test".

        Returns:
            Tensor: the scalar cross-entropy loss.
        """
        imgs, labels = batch
        preds = self.model(imgs).logits
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], _: Tensor) -> Tensor:
        """Run one training step and return the loss for backprop."""
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
        """Run one validation step (metrics are logged, nothing returned)."""
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
        """Run one test step (metrics are logged, nothing returned)."""
        self._calculate_loss(batch, mode="test")