| |
| import pytorch_lightning as pl |
| import segmentation_models_pytorch as smp |
| import torch |
| import torchmetrics |
|
|
class ModelRoiLeish(pl.LightningModule):
    """Binary segmentation LightningModule wrapping a segmentation_models_pytorch model.

    Input images are normalized with the encoder's preprocessing mean/std, the
    model is trained with a binary focal loss, and confusion statistics plus
    IoU scores are logged per step and per epoch for the train/valid/test stages.

    NOTE(review): the ``*_epoch_end(self, outputs)`` hooks below are the
    PyTorch Lightning 1.x API (removed in Lightning 2.0); this module presumably
    targets 1.x — confirm the pinned version before upgrading.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, lr=0.00001, **kwargs):
        """Build the underlying smp model, loss and normalization buffers.

        Args:
            arch: Architecture name passed to ``smp.create_model``.
            encoder_name: Encoder name passed to ``smp.create_model``; also used
                to look up the normalization parameters.
            in_channels: Number of input channels of the model.
            out_classes: Number of output channels (``classes``) of the model.
            lr: Learning rate used by ``configure_optimizers``.
            **kwargs: Extra keyword arguments forwarded to ``smp.create_model``.
        """
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )

        # Encoder normalization statistics, registered as buffers so they move
        # with the module across devices and are stored in the state_dict.
        # NOTE(review): the view hard-codes 3 channels, so forward() only
        # broadcasts when in_channels == 3 — confirm that is the only use case.
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        self.loss_fn = smp.losses.FocalLoss(smp.losses.BINARY_MODE)
        self.lr = lr

        # in_channels / out_classes are required __init__ arguments, so they must
        # be saved too for ``load_from_checkpoint`` to rebuild the module.
        self.save_hyperparameters("lr", "arch", "encoder_name", "in_channels", "out_classes")

    def forward(self, image):
        """Normalize ``image`` with the encoder statistics and return the raw model output (logits)."""
        image = (image - self.mean) / self.std
        return self.model(image)

    def shared_step(self, batch, stage):
        """Run the forward pass, loss and per-batch metrics for one step.

        Args:
            batch: Mapping with an ``"image"`` tensor laid out as (N, C, H, W)
                and a 4-D ``"mask"`` tensor with values in [0, 1].
            stage: Prefix for the logged metric names ("train", "valid", "test").

        Returns:
            Dict with ``loss`` and the per-image ``tp``/``fp``/``fn``/``tn``
            counts, which ``shared_epoch_end`` concatenates at epoch end.
        """
        image = batch["image"]
        assert image.ndim == 4, f"expected a 4-D image batch, got shape {tuple(image.shape)}"

        # Presumably required because the smp encoders downsample by a total
        # factor of 32; other sizes would misalign the decoder skip connections.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0, f"image height/width must be divisible by 32, got {h}x{w}"

        mask = batch["mask"]
        # Presumably (N, 1, H, W) for binary segmentation — TODO confirm with the datamodule.
        assert mask.ndim == 4, f"expected a 4-D mask batch, got shape {tuple(mask.shape)}"
        # Binary targets: every value must lie in [0, 1].
        assert mask.max() <= 1.0 and mask.min() >= 0, "mask values must be in [0, 1]"

        logits_mask = self.forward(image)
        loss = self.loss_fn(logits_mask, mask)

        # Threshold the sigmoid probabilities at 0.5 and collect per-image
        # confusion counts.
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).long()
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, mask.long(), mode="binary")

        # Batch-level Jaccard/IoU pooled over all pixels of the batch, derived
        # from the same counts. This replaces a torchmetrics jaccard_index call
        # that omitted its required num_classes/task argument.
        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        loss_metrics = {
            f"{stage}_loss": loss.to(torch.float32).mean(),
            f"{stage}_tp": tp.to(torch.float32).mean(),
            f"{stage}_fp": fp.to(torch.float32).mean(),
            f"{stage}_fn": fn.to(torch.float32).mean(),
            f"{stage}_tn": tn.to(torch.float32).mean(),
            f"{stage}_jaccard": iou_score.to(torch.float32).mean(),
        }
        self.log_dict(loss_metrics, prog_bar=True)

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Concatenate the step statistics of an epoch and log epoch-level IoU.

        Args:
            outputs: List of the dicts returned by ``shared_step`` during the epoch.
            stage: Prefix for the logged metric names.
        """
        # An epoch with no batches has nothing to aggregate, and torch.cat
        # raises on an empty list.
        if not outputs:
            return

        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # IoU scored per image, then averaged over the images.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        # IoU scored once from counts summed over every image of the epoch.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Optimize all parameters with AdamW at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
|
|
|
|