## MODEL.py
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torchmetrics


class ModelRoiLeish(pl.LightningModule):
    """Lightning wrapper around an ``smp`` segmentation model for binary masks.

    The module normalizes incoming images with the encoder's preprocessing
    statistics, trains with a binary-mode focal loss, and logs per-step
    confusion statistics plus per-epoch IoU scores.

    Args:
        arch: Architecture name passed to ``smp.create_model``.
        encoder_name: Encoder name; also used to look up the preprocessing
            mean/std.
        in_channels: Number of input channels given to the model.
        out_classes: Number of output classes given to the model.
        lr: Learning rate used by the AdamW optimizer.
        **kwargs: Forwarded unchanged to ``smp.create_model``.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, lr=0.00001, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )

        # Image preprocessing statistics, stored as buffers so they follow the
        # module across devices.
        # NOTE(review): the (1, 3, 1, 1) view hard-codes three channels, so an
        # `in_channels` other than 3 will not normalize as intended - confirm
        # before feeding grayscale or multi-band images.
        params = smp.encoders.get_preprocessing_params(encoder_name)
        for stat_name in ("std", "mean"):
            self.register_buffer(stat_name, torch.tensor(params[stat_name]).view(1, 3, 1, 1))

        # Binary-mode focal loss is the training objective.
        self.loss_fn = smp.losses.FocalLoss(smp.losses.BINARY_MODE)

        self.lr = lr
        self.save_hyperparameters('lr', 'arch', 'encoder_name')

    def forward(self, image):
        """Normalize ``image`` and return the model's raw (logit) mask."""
        normalized = (image - self.mean) / self.std
        return self.model(normalized)

    def shared_step(self, batch, stage):
        """Run one train/valid/test step and log its metrics.

        Args:
            batch: Mapping with an ``"image"`` and a ``"mask"`` entry, both
                4-dimensional tensors.
            stage: Prefix used for the logged metric names.

        Returns:
            Dict with the loss and the tp/fp/fn/tn statistics of this batch;
            the statistics are aggregated at the end of the epoch.
        """
        image = batch["image"]

        # Images must be (batch_size, num_channels, height, width); grayscale
        # input needs an explicit channel dimension.
        assert image.ndim == 4

        # Height and width must be divisible by 32: encoder and decoder are
        # joined by skip connections and the encoder usually downsamples five
        # times by a factor of 2 (2 ** 5 = 32). With an 84x84 image, for
        # example, the feature sizes would be 84, 42, 21, 10, 5 on the way down
        # but 5, 10, 20, 40, 80 on the way up, and concatenating them fails.
        height, width = image.shape[2:]
        assert height % 32 == 0 and width % 32 == 0

        mask = batch["mask"]

        # Masks must be (batch_size, num_classes, height, width), with
        # num_classes = 1 for binary segmentation.
        assert mask.ndim == 4

        # Mask values must lie in [0, 1], not [0, 255].
        assert mask.max() <= 1.0 and mask.min() >= 0

        # The model output is raw logits, and the loss is computed on them
        # directly (no sigmoid is applied beforehand).
        logits_mask = self.forward(image)
        loss = self.loss_fn(logits_mask, mask)

        # Convert logits to probabilities for the metrics.
        prob_mask = logits_mask.sigmoid()
        target = mask.long()

        # NOTE(review): this call passes neither `task` nor `num_classes`;
        # confirm the installed torchmetrics release accepts that signature.
        iou_score = torchmetrics.functional.jaccard_index(prob_mask, target)

        # Threshold the probabilities at 0.5 to obtain a hard prediction.
        pred_mask = (prob_mask > 0.5).float()

        # IoU is later reported both dataset-wise and image-wise; here only the
        # true/false positive/negative statistics are collected, to be
        # aggregated at the end of the epoch.
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), target, mode="binary")

        def as_scalar(value):
            # Cast to float32 first so integer statistics can be averaged.
            return value.to(torch.float32).mean()

        loss_metrics = {
            f"{stage}_loss": as_scalar(loss),
            f"{stage}_tp": as_scalar(tp),
            f"{stage}_fp": as_scalar(fp),
            f"{stage}_fn": as_scalar(fn),
            f"{stage}_tn": as_scalar(tn),
            f"{stage}_jaccard": as_scalar(iou_score),
        }
        self.log_dict(loss_metrics, prog_bar=True)

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate the step statistics of an epoch and log IoU scores.

        Args:
            outputs: Sequence of the dicts returned by ``shared_step``.
            stage: Prefix used for the logged metric names.
        """
        # Concatenate each statistic over all steps of the epoch.
        tp, fp, fn, tn = (
            torch.cat([step_out[key] for step_out in outputs])
            for key in ("tp", "fp", "fn", "tn")
        )

        # Per-image IoU: score every image first, then average the scores.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # Dataset IoU: aggregate intersection and union over the whole dataset
        # and score once. The two usually differ little, but datasets with
        # "empty" images (no target class) can show a large gap, because empty
        # images weigh heavily on the per-image score and much less on the
        # dataset score.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Training step; delegates to ``shared_step`` with the "train" prefix."""
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        """Log the epoch-level IoU metrics for training."""
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step; delegates to ``shared_step`` with the "valid" prefix."""
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        """Log the epoch-level IoU metrics for validation."""
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        """Test step; delegates to ``shared_step`` with the "test" prefix."""
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        """Log the epoch-level IoU metrics for testing."""
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with ``self.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)