| | import lightning as L |
| | import torch |
| | import torch.nn.functional as F |
| | import timm |
| | from torch import optim |
| | from torchmetrics import Accuracy |
| | from torch.optim.lr_scheduler import OneCycleLR |
| |
|
| |
|
class ImageNetClassifier(L.LightningModule):
    """ResNet-50 classifier with Adam and a one-cycle learning-rate schedule.

    The backbone is created through ``timm`` without pretrained weights and
    with a 1000-way classification head. Loss and top-1 accuracy are logged
    once per epoch for both the training and validation loops.

    Args:
        lr: Peak learning rate of the one-cycle schedule. It is also the
            value handed to Adam, although the scheduler overrides the
            optimizer's starting rate (see ``configure_optimizers``).
    """

    # Number of target classes; shared by the model head and both metrics so
    # the three can never drift apart.
    NUM_CLASSES = 1000

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr

        self.model = timm.create_model(
            'resnet50', pretrained=False, num_classes=self.NUM_CLASSES
        )
        # Separate metric objects so training and validation statistics are
        # accumulated independently.
        self.train_acc = Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.val_acc = Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.save_hyperparameters()

    def forward(self, x):
        """Return the backbone's output (class logits) for the input ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the forward pass common to training and validation.

        Args:
            batch: A pair unpacked as ``(inputs, targets)``.
                NOTE(review): targets are presumably integer class indices,
                as ``F.cross_entropy`` is called without soft labels --
                confirm against the dataloader.

        Returns:
            A ``(loss, logits, targets)`` tuple.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update/log epoch-level train metrics.

        Returns:
            The cross-entropy loss used for backpropagation.
        """
        loss, logits, y = self._shared_step(batch)
        # Logits are passed straight to the metric: multiclass accuracy takes
        # the arg-max over the class dimension of floating-point predictions,
        # so a softmax beforehand is redundant work and cannot change the
        # predicted class.
        self.train_acc(logits, y)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_acc", self.train_acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and update/log epoch-level val metrics."""
        loss, logits, y = self._shared_step(batch)
        # See training_step: the metric consumes raw logits directly.
        self.val_acc(logits, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Build the Adam optimizer and a per-step ``OneCycleLR`` scheduler.

        Returns:
            A dict with the optimizer and a scheduler config whose
            ``"interval"`` is ``"step"``, so the schedule advances on every
            optimizer step rather than once per epoch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        # The schedule length must cover the whole run; the trainer supplies
        # its estimate of the total number of optimizer steps.
        total_steps = self.trainer.estimated_stepping_batches

        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.lr,            # peak learning rate
            total_steps=total_steps,
            pct_start=0.3,             # fraction of the cycle spent ramping up
            div_factor=25,             # initial lr = max_lr / div_factor
            final_div_factor=1e4,      # final lr = initial lr / final_div_factor
            three_phase=False,
            anneal_strategy='cos',
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }