leishmaniaModel / model.py
vannynakamura's picture
Update model.py
7f8dcb3
## MODEL.py
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torchmetrics
class ModelRoiLeish(pl.LightningModule):
    """Lightning module wrapping an ``smp`` model for binary mask segmentation.

    The wrapped network is built with ``smp.create_model``. Inputs are
    normalized inside ``forward`` with the encoder's preprocessing statistics,
    the training objective is a binary focal loss, and IoU is reported both per
    step (via ``torchmetrics``) and per epoch (via ``smp.metrics``).
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, lr=0.00001, **kwargs):
        """Build the segmentation network, normalization buffers and loss.

        Args:
            arch: Architecture name forwarded to ``smp.create_model``.
            encoder_name: Encoder name forwarded to ``smp.create_model`` and
                used to look up the preprocessing mean/std.
            in_channels: Number of input channels of the network.
            out_classes: Number of output channels (passed as ``classes``).
            lr: Learning rate handed to ``AdamW`` in ``configure_optimizers``.
            **kwargs: Extra keyword arguments forwarded to ``smp.create_model``.
        """
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # Preprocessing statistics for the chosen encoder. They are registered
        # as buffers so they follow the module across devices without being
        # treated as trainable parameters.
        # NOTE(review): the view hard-codes 3 channels, so the normalization in
        # `forward` is only meaningful for 3-channel inputs — confirm that
        # `in_channels` is always 3 for this project.
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # Binary focal loss; it is fed the raw logits produced by the model.
        self.loss_fn = smp.losses.FocalLoss(smp.losses.BINARY_MODE)
        self.lr = lr

        # Record every required constructor argument so that
        # `load_from_checkpoint` can rebuild the module without the caller
        # having to re-supply `in_channels` / `out_classes`. Extra `**kwargs`
        # are still not recorded.
        self.save_hyperparameters('lr', 'arch', 'encoder_name', 'in_channels', 'out_classes')

    def forward(self, image):
        """Predict the segmentation logits for a batch of images.

        Args:
            image: Image batch; it is normalized here with the encoder's
                mean/std before being passed to the network.

        Returns:
            The raw (un-activated) output of the wrapped model.
        """
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        """Run one train/valid/test step: validate input, compute loss, log metrics.

        Args:
            batch: Mapping with an ``"image"`` and a ``"mask"`` entry.
            stage: Prefix used for the logged metric names
                (``"train"``, ``"valid"`` or ``"test"``).

        Returns:
            Dict with the ``loss`` and the per-image ``tp``/``fp``/``fn``/``tn``
            statistics, which are aggregated in ``shared_epoch_end``.
        """
        image = batch["image"]

        # Images must be (batch_size, num_channels, height, width); grayscale
        # images need an explicit channel dimension: (batch_size, 1, height, width).
        assert image.ndim == 4, "image must have shape (batch, channels, height, width)"

        # Encoder and decoder are joined by skip connections and the encoder
        # usually downsamples 5 times by a factor of 2 (2 ** 5 = 32). If a side
        # is not divisible by 32 the shapes stop matching: e.g. an 84x84 input
        # gives encoder features of size 84, 42, 21, 10, 5 while the decoder
        # upsamples through 5, 10, 20, 40, 80, and concatenating them fails.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0, "image height and width must be divisible by 32"

        mask = batch["mask"]

        # Masks must be (batch_size, num_classes, height, width); for binary
        # segmentation num_classes = 1.
        assert mask.ndim == 4, "mask must have shape (batch, classes, height, width)"

        # Binary masks must be scaled to [0, 1], NOT [0, 255].
        assert mask.max() <= 1.0 and mask.min() >= 0, "mask values must lie in [0, 1]"

        logits_mask = self.forward(image)

        # The model output is raw logits, which is what the focal loss consumes.
        loss = self.loss_fn(logits_mask, mask)

        # Turn logits into probabilities for the metrics.
        prob_mask = logits_mask.sigmoid()

        # NOTE(review): this positional call matches older torchmetrics
        # releases; newer releases appear to require extra arguments
        # (`task` / `num_classes`) — confirm against the pinned version.
        iou_score = torchmetrics.functional.jaccard_index(prob_mask, mask.long())

        # Hard predictions at a 0.5 threshold.
        pred_mask = prob_mask > 0.5

        # IoU is later computed in two ways (dataset-wise and image-wise). Here
        # we only count true/false positive/negative pixels per image and
        # class; they are aggregated at the end of the epoch.
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        loss_metrics = {
            f"{stage}_loss": loss.to(torch.float32).mean(),
            f"{stage}_tp": tp.to(torch.float32).mean(),
            f"{stage}_fp": fp.to(torch.float32).mean(),
            f"{stage}_fn": fn.to(torch.float32).mean(),
            f"{stage}_tn": tn.to(torch.float32).mean(),
            f"{stage}_jaccard": iou_score.to(torch.float32).mean()
        }
        self.log_dict(loss_metrics, prog_bar=True)

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate the step statistics of one epoch and log IoU scores.

        Args:
            outputs: Sequence of the dicts returned by ``shared_step``.
            stage: Prefix used for the logged metric names.
        """
        # Aggregate step metrics.
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # Per-image IoU: compute the IoU score of each image first, then
        # average those scores.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # Dataset IoU: accumulate intersection and union over the whole dataset
        # and compute a single IoU from them. For datasets with "empty" images
        # (no target class present) the two scores can differ a lot, because
        # empty images weigh heavily on per_image_iou and much less on
        # dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        self.log_dict(metrics, prog_bar=True)

    # NOTE(review): the `*_epoch_end(outputs)` hooks below belong to the
    # Lightning 1.x API — confirm the pinned pytorch_lightning version still
    # invokes them.

    def training_step(self, batch, batch_idx):
        """Run ``shared_step`` with the ``"train"`` metric prefix."""
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        """Log the epoch-level IoU scores for the training stage."""
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        """Run ``shared_step`` with the ``"valid"`` metric prefix."""
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        """Log the epoch-level IoU scores for the validation stage."""
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        """Run ``shared_step`` with the ``"test"`` metric prefix."""
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        """Log the epoch-level IoU scores for the test stage."""
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Return an ``AdamW`` optimizer over all parameters using ``self.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)