PriyePrabhakar's picture
Add application file 1
0c717d3
raw
history blame
2.2 kB
import lightning as L
import torch
import torch.nn.functional as F
import timm
from torch import optim
from torchmetrics import Accuracy
from torch.optim.lr_scheduler import OneCycleLR
class ImageNetClassifier(L.LightningModule):
    """Lightning module wrapping a timm image classifier.

    Trains with cross-entropy loss, tracks multiclass accuracy on the
    training and validation splits, and optimizes with Adam under a
    per-step OneCycleLR schedule whose peak learning rate is ``lr``.

    Args:
        lr: Peak learning rate (``max_lr``) of the OneCycle schedule.
        num_classes: Number of output classes for the model head and the
            accuracy metrics. Defaults to 1000.
        model_name: timm architecture name passed to ``timm.create_model``.
            Defaults to ``"resnet50"``.
        pretrained: Whether timm should load pretrained weights.
            Defaults to ``False`` (train from scratch).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        num_classes: int = 1000,
        model_name: str = "resnet50",
        pretrained: bool = False,
    ):
        super().__init__()
        self.lr = lr
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=num_classes
        )
        # Separate metric instances so train and val accumulate independently.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        # Stores the __init__ arguments in self.hparams / checkpoints.
        self.save_hyperparameters()

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch, metric):
        """Run the model on one batch, update ``metric``, and return the loss."""
        # NOTE(review): assumes batch is an (images, integer class targets) pair.
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Accuracy takes the argmax of float predictions, and softmax is
        # monotonic, so raw logits give the same result without the extra op.
        metric(logits, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        loss = self._shared_step(batch, self.train_acc)
        # sync_dist averages the epoch-level loss across devices under DDP.
        self.log(
            "train_loss", loss,
            prog_bar=True, on_step=False, on_epoch=True, sync_dist=True,
        )
        # Passing the Metric object lets Lightning compute/reset it per epoch.
        self.log(
            "train_acc", self.train_acc,
            prog_bar=True, on_step=False, on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy."""
        loss = self._shared_step(batch, self.val_acc)
        self.log(
            "val_loss", loss,
            prog_bar=True, on_step=False, on_epoch=True, sync_dist=True,
        )
        self.log(
            "val_acc", self.val_acc,
            prog_bar=True, on_step=False, on_epoch=True,
        )

    def configure_optimizers(self):
        """Build Adam with a OneCycleLR schedule stepped after every optimizer step.

        The cycle length is ``trainer.estimated_stepping_batches``, so it
        spans the whole run. The first 30% of steps warm up from
        ``lr / 25`` to ``lr``. The remainder cosine-anneals down to
        ``lr / (25 * 1e4)``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=total_steps,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1e4,
            three_phase=False,
            anneal_strategy="cos",
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # OneCycleLR is sized in optimizer steps, so step it per batch.
                "interval": "step",
            },
        }