from typing import Optional

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.utilities.model_summary import ModelSummary
from torchvision import transforms


class ClassifierModel(L.LightningModule):
    """LightningModule wrapping an arbitrary ``nn.Module`` for multiclass classification.

    Trains with cross-entropy loss and Adam, and logs loss and F1 score once
    per epoch for the training and validation loops.

    Args:
        model: Network whose output is passed to ``F.cross_entropy`` as logits.
        image_size: Side length used to build ``example_input_array``, a dummy
            batch of shape ``(5, 3, image_size, image_size)``.
        learning_rate: Learning rate for the Adam optimizer.
        num_classes: Number of classes passed to the F1 metrics.
        train_transform: Stored on the instance; not used by this class itself.
        val_transform: Stored on the instance; not used by this class itself.
    """

    def __init__(self, model: nn.Module, image_size: int = 500, learning_rate: float = 1e-3,
                 num_classes: int = 3, train_transform: Optional[transforms.Compose] = None,
                 val_transform: Optional[transforms.Compose] = None) -> None:
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # Dummy batch Lightning can feed through forward() (e.g. for model
        # summaries). Zeros rather than ``torch.Tensor(...)``, which returns
        # uninitialized memory that may contain NaN/inf.
        self.example_input_array = torch.zeros(5, 3, image_size, image_size)
        # One metric instance per loop: a torchmetrics Metric accumulates state
        # across update() calls, so sharing a single instance would mix
        # training and validation predictions. ``f1_score`` keeps its original
        # name (now training-only) so existing references still resolve.
        self.f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
        self.val_f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
        self.train_transform = train_transform
        self.val_transform = val_transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's output for input batch ``x``."""
        return self.model(x)

    def print_summary(self) -> None:
        """Print a layer-by-layer summary of the full module tree (``max_depth=-1``)."""
        print(ModelSummary(self, max_depth=-1))

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(params=self.model.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch: tuple, metric: torchmetrics.Metric) -> torch.Tensor:
        """Run one ``(X, y)`` batch, accumulate it into ``metric``, and return the loss."""
        X, y = batch
        y_pred = self(X)
        loss = F.cross_entropy(y_pred, y)
        # Only accumulate here; the epoch value is computed and the state reset
        # by Lightning because the Metric object itself is logged.
        metric.update(y_pred, y)
        return loss

    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Compute the training loss for one batch and log loss/F1 per epoch."""
        loss = self._shared_step(batch, self.f1_score)
        self.log_dict({'Train loss': loss, 'Train F1 score': self.f1_score},
                      on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Compute the validation loss for one batch and log loss/F1 per epoch."""
        loss = self._shared_step(batch, self.val_f1_score)
        self.log_dict({'Validation loss': loss, 'Validation F1 score': self.val_f1_score},
                      on_step=False, on_epoch=True)
        return loss