import torch
import torch.nn as nn
import lightning as pl
import torchvision.models as models
from torchmetrics import Accuracy


class Classifier(pl.LightningModule):
    def __init__(self, transfer=True):
        super().__init__()
        # Map single-channel input to the 3 channels EfficientNet expects.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.model = models.efficientnet_b1(weights='IMAGENET1K_V1')
        if transfer:
            # Freeze the pretrained backbone. Note that eval() does not freeze
            # weights (it only switches BatchNorm/Dropout to inference mode,
            # and the Trainer flips the module back to train() mode anyway);
            # requires_grad=False is what actually stops the updates.
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        # EfficientNet-B1's classifier takes 1280 input features.
        num_ftrs = 1280
        # The replacement head is created after freezing, so its parameters
        # stay trainable. No Softmax layer here: CrossEntropyLoss expects raw
        # logits and applies log-softmax internally, so adding Softmax before
        # it would squash the gradients and hurt training.
        self.model.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_ftrs, out_features=2),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def forward(self, x):
        x = self.conv(x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        # Calculate and log accuracy
        preds = torch.argmax(outputs, dim=1)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        # Calculate and log accuracy
        preds = torch.argmax(outputs, dim=1)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        self.val_accuracy.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5)
        # The 'monitor' key belongs inside the lr_scheduler dict; a duplicate
        # top-level 'monitor' key is rejected by recent Lightning versions.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
        }