Spaces:
Sleeping
Sleeping
demo to spaces
Browse files- LICENSE +21 -0
- README.md +19 -13
- __pycache__/custom_unet.cpython-310.pyc +0 -0
- __pycache__/early_stopper.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/unet.cpython-310.pyc +0 -0
- app.py +55 -0
- custom_unet.py +135 -0
- early_stopper.py +23 -0
- examples/example_1.jpg +0 -0
- examples/example_2.jpg +0 -0
- examples/example_3.jpg +0 -0
- model.py +130 -0
- models/custom_unet.pth +3 -0
- models/unet.pth +3 -0
- requirements.txt +9 -0
- unet.py +67 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 i4ata
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,13 +1,19 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# UnetSegmentation
|
| 2 |
+
|
| 3 |
+
My own implementation of the U-net architecture compared to a pretrained model from PyTorch Segmentation Models.
|
| 4 |
+
|
| 5 |
+
To train:
|
| 6 |
+
|
| 7 |
+
```python train.py```
|
| 8 |
+
|
| 9 |
+
To visualize training:
|
| 10 |
+
|
| 11 |
+
```tesnorboard --logdir runs```
|
| 12 |
+
|
| 13 |
+
To visually compare models on some examples:
|
| 14 |
+
|
| 15 |
+
```python compare_models.py```
|
| 16 |
+
|
| 17 |
+
To launch a Gradio application:
|
| 18 |
+
|
| 19 |
+
```python3 gradio_app.py```
|
__pycache__/custom_unet.cpython-310.pyc
ADDED
|
Binary file (5.68 kB). View file
|
|
|
__pycache__/early_stopper.cpython-310.pyc
ADDED
|
Binary file (956 Bytes). View file
|
|
|
__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (4.59 kB). View file
|
|
|
__pycache__/unet.cpython-310.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from model import SegmentationModel
|
| 9 |
+
from custom_unet import CustomUnet
|
| 10 |
+
from unet import Unet
|
| 11 |
+
|
| 12 |
+
from typing import Dict, Union, Tuple, List
|
| 13 |
+
|
| 14 |
+
class GradioApp:
    """Gradio demo that segments people in images with one of two U-nets.

    Models are referenced by display name and loaded lazily on first use:
    the values in ``self.models`` start out as model file names (``str``)
    and are replaced in place by the loaded ``SegmentationModel`` instances.
    """

    def __init__(self) -> None:
        # Display name -> model file name (str) until the first prediction,
        # after which the value becomes the loaded model instance.
        self.models: Dict[str, Union[str, SegmentationModel]] = {
            'Custom': 'custom_unet',
            'Pretrained': 'unet'
        }

    def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
        """Segment the image at ``img_file`` with the chosen model.

        Returns the input path and a list of ``(mask, label)`` pairs in the
        format expected by ``gr.AnnotatedImage``.
        """
        # Lazy loading of models: a str value means "not loaded yet".
        if isinstance(self.models[model_name], str):
            model_class = CustomUnet if model_name == 'Custom' else Unet
            self.models[model_name] = model_class(self.models[model_name], from_file=True, device='cpu')
            self.models[model_name].eval()
        # `* 1` casts the boolean mask to an integer array for display.
        prediction = self.models[model_name].predict(img_file, option='mask')[0] * 1
        return img_file, [(prediction, 'person')]

    def launch(self):
        """Build and launch the Gradio interface with the bundled examples."""
        examples_list = [['examples/' + example] for example in os.listdir('examples')]

        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to segment'),
                gr.Radio(choices=('Custom', 'Pretrained'), label='Available models')
            ],
            outputs=gr.AnnotatedImage(label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
        )
        demo.launch()
|
| 52 |
+
|
| 53 |
+
# Entry point: start the Gradio demo when run as a script.
if __name__ == '__main__':
    GradioApp().launch()
|
custom_unet.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This python module implements the Unet architecture as defined in https://arxiv.org/pdf/1505.04597.
|
| 2 |
+
Only, I use padded convolutions. That way, there is no need for center cropping and the output mask
|
| 3 |
+
is the same shape as the input image.
|
| 4 |
+
|
| 5 |
+
Additional things: https://towardsdatascience.com/understanding-u-net-61276b10f360
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torchinfo import summary
|
| 11 |
+
|
| 12 |
+
from model import SegmentationModel
|
| 13 |
+
from early_stopper import EarlyStopper
|
| 14 |
+
|
| 15 |
+
from typing import Tuple, Union, Optional
|
| 16 |
+
|
| 17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 18 |
+
|
| 19 |
+
class DiceLoss(nn.Module):
    """Soft (differentiable) Dice loss for binary segmentation logits.

    Returns ``1 - dice_coefficient`` so that *minimizing* the loss
    maximizes the overlap between predicted and true masks. Predictions
    are kept as probabilities (no thresholding) so gradients can flow.

    Fixes over the previous version: it returned the Dice coefficient
    itself (so adding it to BCE rewarded low overlap), it thresholded the
    sigmoid output (killing all gradients), and it could divide by zero
    for empty masks.
    """

    def forward(self, logits: torch.Tensor, mask_true: torch.Tensor) -> torch.Tensor:
        # Smoothing term avoids division by zero when both masks are empty.
        eps = 1e-6
        probs = torch.sigmoid(logits)
        intersection = (probs * mask_true).sum()
        union = probs.sum() + mask_true.sum()
        # Dice coefficient lies in [0, 1]; its complement is the loss.
        return 1 - (2 * intersection + eps) / (union + eps)
|
| 26 |
+
|
| 27 |
+
class DoubleConv(nn.Module):
    """Two 3x3 'same'-padded convolutions, each followed by a ReLU.

    'same' padding keeps the spatial size unchanged, so no center
    cropping is needed anywhere in the network.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Attribute names are kept stable so pickled checkpoints keep working.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> ReLU twice, preserving the spatial dimensions."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x
|
| 38 |
+
|
| 39 |
+
class Up(nn.Module):
    """Decoder stage: upsample the deeper features and fuse them with the skip connection."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # 2x2 transposed convolution doubles the spatial resolution while
        # halving the channel count. Attribute names kept for checkpoint
        # compatibility.
        self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        """Upsample ``x_right``, concatenate with ``x_left`` along channels, then convolve."""
        upsampled = self.upconv(x_right)
        fused = torch.cat((x_left, upsampled), dim=1)
        return self.conv(fused)
|
| 48 |
+
|
| 49 |
+
class UnetModel(nn.Module):
    """Plain U-net: an encoder that doubles channels while halving the
    spatial size at each stage, mirrored by a decoder with skip
    connections, ending in a 1-channel (binary) output map.
    """

    def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
        super().__init__()

        # Channel widths at each resolution level: start, 2*start, 4*start, ...
        widths = [start_channels * 2 ** i for i in range(depth + 1)]

        # Attribute names are kept stable so pickled checkpoints keep working.
        self.input_conv = DoubleConv(in_channels, widths[0])

        self.encoder_layers = nn.ModuleList(
            DoubleConv(narrow, wide) for narrow, wide in zip(widths, widths[1:])
        )

        # Decoder mirrors the encoder: widest pair first, narrowest last.
        self.decoder_layers = nn.ModuleList(
            Up(wide, narrow) for wide, narrow in zip(widths[:0:-1], widths[-2::-1])
        )

        # 1x1 convolution maps back to a single-channel logit map.
        self.output_conv = nn.Conv2d(widths[0], 1, kernel_size=1)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode, then decode with skip connections; return per-pixel logits."""
        x = self.input_conv(x)
        skip_connections = [x]

        for encode in self.encoder_layers:
            x = encode(self.pool(x))
            skip_connections.append(x)

        # The deepest features are already in `x`; the rest are the skips,
        # consumed shallowest-last (pop() yields them deepest-first).
        skip_connections.pop()
        for decode in self.decoder_layers:
            x = decode(skip_connections.pop(), x)

        return self.output_conv(x)
|
| 84 |
+
|
| 85 |
+
class CustomUnet(SegmentationModel):
    """Hand-written U-net (``UnetModel``) wrapped in the ``SegmentationModel``
    training/prediction interface, with a combined BCE + Dice objective.
    """

    def __init__(self,
                 name: str = 'default_name',
                 from_file: bool = True,
                 image_size: Tuple[int, int] = (320, 320),
                 in_channels: int = 3,
                 start_channels: int = 16,
                 encoder_depth: int = 5,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
        """Load a saved model from ``models/{name}.pth`` (``from_file=True``)
        or build a fresh ``UnetModel`` with the given architecture parameters.
        """

        super().__init__()

        # Each encoder stage halves the spatial size, so both dimensions must
        # divide evenly by 2**encoder_depth.
        # NOTE(review): `assert` is stripped under `python -O`; consider raising ValueError instead.
        assert image_size[0] % (2**encoder_depth) == 0
        assert image_size[1] % (2**encoder_depth) == 0

        self.name = name
        self.image_size = image_size
        self.in_channels = in_channels
        self.device = device

        self.save_path = f'models/{name}.pth'

        if from_file:
            # Loads the whole pickled module (see save()); only load trusted files.
            self.unet = torch.load(self.save_path, map_location=device)
        else:
            self.unet = UnetModel(in_channels=in_channels, depth=encoder_depth, start_channels=start_channels).to(device)

        # Combined objective: pixel-wise BCE on logits plus the Dice term.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
        self.loss_fn = lambda logits, masks: self.bce_loss(logits, masks) + self.dice_loss(logits, masks)

    def configure_optimizers(self, **kwargs):
        """Create the Adam optimizer and early stopper (expects 'lr' and 'patience' kwargs)."""
        self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
        self.early_stopper = EarlyStopper(patience=kwargs['patience'])

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits; when ``masks`` is provided, also return the loss."""
        logits = self.unet(images)
        if masks is None:
            return logits
        return logits, self.loss_fn(logits, masks)

    def save(self) -> None:
        # Save the whole model, not only the state dict, so that it will work for different unets
        torch.save(self.unet, self.save_path)

    def print_summary(self, batch_size: int = 16) -> None:
        """Print a torchinfo summary of the underlying unet for the given batch size."""

        print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
                      col_names=['input_size', 'output_size', 'num_params'],
                      row_settings=['var_names']))
|
early_stopper.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains a class that implements early stopping regularization technique"""
|
| 2 |
+
|
| 3 |
+
class EarlyStopper:
    """Tracks the validation loss and signals when training should stop early."""

    def __init__(self, patience: int = 2):
        # Number of consecutive non-improving epochs tolerated before stopping.
        self.patience = patience
        # Best (lowest) validation loss seen so far.
        self.best_loss = float('inf')
        # Consecutive epochs without improvement.
        self.counter = 0
        # True right after an improving epoch: the caller should checkpoint then.
        self.save_model = False

    def check(self, validation_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        self.save_model = False
        if validation_loss <= self.best_loss:
            # Improvement (ties count as improvement): reset and request a save.
            self.best_loss = validation_loss
            self.counter = 0
            self.save_model = True
            return False
        self.counter += 1
        return self.counter == self.patience
|
examples/example_1.jpg
ADDED
|
examples/example_2.jpg
ADDED
|
examples/example_3.jpg
ADDED
|
model.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains the base class for segmentation models"""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torchvision.utils import draw_segmentation_masks
|
| 7 |
+
from torchvision.transforms.functional import resize
|
| 8 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2 as cv
|
| 12 |
+
import albumentations as A
|
| 13 |
+
|
| 14 |
+
from typing import Optional, Union, Tuple, Literal
|
| 15 |
+
|
| 16 |
+
from early_stopper import EarlyStopper
|
| 17 |
+
|
| 18 |
+
class SegmentationModel(nn.Module):
    """Base class for binary segmentation models.

    Subclasses implement ``configure_optimizers``, ``forward``, ``save`` and
    ``print_summary`` and are expected to set the attributes declared below
    in their ``__init__``.
    """

    # Display name; also used as the TensorBoard run sub-directory.
    name: str = "base name"
    # Device to which batches are moved; set by the subclass.
    device: Literal['cpu', 'cuda'] = None

    # Training components, populated by configure_optimizers() in subclasses.
    optimizer: torch.optim.Optimizer = None
    early_stopper: EarlyStopper = None
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
    # Where save() persists the model.
    save_path: str = None
    # (height, width) the input image is resized to before inference.
    image_size: Tuple[int, int] = None

    def configure_optimizers(self, **kwargs) -> None:
        """Create the optimizer (and optionally early stopper / scheduler). Subclass hook."""
        raise NotImplementedError()

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits, or ``(logits, loss)`` when ``masks`` is provided. Subclass hook."""

        raise NotImplementedError()

    def _train_step(self, data_loader: DataLoader) -> float:
        """Run one training epoch; return the mean per-batch loss."""

        self.train()
        total_loss = 0.
        for images, masks in data_loader:
            images, masks = images.to(self.device), masks.to(self.device)

            self.optimizer.zero_grad()
            logits, loss = self(images, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(data_loader)

    def _test_step(self, data_loader: DataLoader) -> float:
        """Evaluate one epoch without gradient tracking; return the mean per-batch loss."""

        self.eval()
        total_loss = 0.
        with torch.inference_mode():
            for images, masks in data_loader:
                images, masks = images.to(self.device), masks.to(self.device)
                logits, loss = self(images, masks)
                total_loss += loss.item()
        return total_loss / len(data_loader)

    def train_model(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int, log_dir: str) -> None:
        """Train for up to ``epochs`` epochs, logging both losses to TensorBoard.

        When an early stopper is configured, training stops once the test
        loss stops improving, and the model is checkpointed on each
        improving epoch. If the loop completes without early stopping, the
        final model is saved instead (see the ``for``/``else`` below).
        """

        writer = SummaryWriter(log_dir=f'{log_dir}/{self.name}')

        for i in range(epochs):
            train_loss = self._train_step(train_loader)
            test_loss = self._test_step(test_loader)

            if self.early_stopper is not None:
                if self.early_stopper.check(test_loss):
                    print(f'Model stopped early due to risk of overfitting')
                    break

                # The stopper flags epochs where the test loss improved.
                if self.early_stopper.save_model:
                    self.save()
                    print('saved model')

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            print(f'{i}: Train loss: {train_loss :.2} | Test loss: {test_loss :.2}')

            writer.add_scalars(main_tag='Loss over time',
                               tag_scalar_dict={'train loss': train_loss, 'test loss': test_loss},
                               global_step=i)

        else:
            # for/else: this branch runs only when the loop was NOT broken,
            # i.e. training ran all epochs without early stopping.
            if self.early_stopper is not None:
                print('Model did not converge. Possibility of underfitting')
            self.save()
        writer.close()

    def save(self) -> None:
        """Persist the model to ``self.save_path``. Subclass hook."""
        raise NotImplementedError()

    def predict(self,
                test_image_path: str,
                option: Literal['mask', 'image_with_mask', 'mask_and_image_with_mask'] = 'image_with_mask'
                ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Segment the image stored at ``test_image_path``.

        Depending on ``option``, returns the boolean mask (at the original
        image resolution), the original image with the mask drawn on top,
        or both. Falls through and returns None for any other option value.
        """

        self.eval()
        input_resizer = A.Resize(*self.image_size)

        # Read as RGB; keep the original-resolution image for the overlay.
        original_image = cv.cvtColor(cv.imread(test_image_path), cv.COLOR_BGR2RGB)
        original_image_tensor = torch.from_numpy(original_image).permute(2,0,1).type(torch.uint8)
        # Resize to the network's input size and scale pixels to [0, 1].
        resized_image_tensor = (torch.from_numpy(input_resizer(image=original_image)['image']).float() / 255.).permute(2,0,1)

        with torch.inference_mode():
            logits = self(resized_image_tensor.unsqueeze(0).to(self.device)).squeeze(0).cpu().detach()
            probs = torch.sigmoid(logits)
            resized_mask_tensor = probs > .5

        # Scale the predicted mask back to the original image resolution.
        original_mask_tensor = resize(resized_mask_tensor, size=original_image.shape[:-1], antialias=True)

        image_with_mask = draw_segmentation_masks(image=original_image_tensor,
                                                  masks=original_mask_tensor,
                                                  alpha=.5,
                                                  colors='white')

        if option == 'mask':
            return original_mask_tensor.numpy()
        if option == 'image_with_mask':
            return image_with_mask.permute(1,2,0).numpy()
        if option == 'mask_and_image_with_mask':
            return original_mask_tensor.numpy(), image_with_mask.permute(1,2,0).numpy()

    def print_summary(self) -> None:
        """Print an architecture summary. Subclass hook."""
        raise NotImplementedError()
|
models/custom_unet.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eeca616e3026a77a2125e4c880f5335e1efa4a2c53b1ad4dad0082e227e49b85
|
| 3 |
+
size 7812958
|
models/unet.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d26354f766301bc980c66f3984599491a4c3dc35706dc3e31a95f115e30a74c6
|
| 3 |
+
size 25378610
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchinfo
|
| 3 |
+
segmentation-models-pytorch
|
| 4 |
+
albumentations
|
| 5 |
+
opencv-python
|
| 6 |
+
gradio
|
| 7 |
+
numpy
|
| 8 |
+
matplotlib
|
| 9 |
+
tensorboard
|
unet.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module defines a Unet architecture"""
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch
|
| 5 |
+
from torchinfo import summary
|
| 6 |
+
import segmentation_models_pytorch as smp
|
| 7 |
+
from early_stopper import EarlyStopper
|
| 8 |
+
|
| 9 |
+
from model import SegmentationModel
|
| 10 |
+
|
| 11 |
+
from typing import Optional, Union, Tuple
|
| 12 |
+
|
| 13 |
+
class Unet(SegmentationModel):
    """Pretrained U-net from ``segmentation_models_pytorch`` wrapped in the
    ``SegmentationModel`` interface, with a combined BCE + Dice objective.
    """

    def __init__(self,
                 name: str = 'default_name',
                 from_file: bool = True,
                 image_size: Tuple[int, int] = (320, 320),
                 encoder_name: str = 'timm-efficientnet-b0',
                 pretrained: bool = True,
                 in_channels: int = 3,
                 encoder_depth: int = 5,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
        """Load a saved model from ``models/{name}.pth`` (``from_file=True``)
        or build an smp.Unet with the given encoder configuration.
        """

        super().__init__()

        self.name = name
        self.image_size = image_size
        self.in_channels = in_channels
        self.device = device

        self.save_path = f'models/{name}.pth'

        if from_file:
            # Loads the whole pickled module (see save()); only load trusted files.
            self.unet = torch.load(self.save_path, map_location=device)
        else:
            # classes=1 / activation=None: raw single-channel logits output.
            self.unet = smp.Unet(
                encoder_name=encoder_name,
                encoder_weights='imagenet' if pretrained else None,
                in_channels=in_channels,
                encoder_depth=encoder_depth,
                classes=1,
                activation=None
            ).to(device)

        # Combined objective: pixel-wise BCE on logits plus binary Dice loss.
        bce_loss_fn = nn.BCEWithLogitsLoss()
        dice_loss_fn = smp.losses.DiceLoss(mode='binary')
        self.loss_fn = lambda logits, masks: bce_loss_fn(logits, masks) + dice_loss_fn(logits, masks)

    def configure_optimizers(self, **kwargs):
        """Create the Adam optimizer and early stopper (expects 'lr' and 'patience' kwargs)."""
        self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
        self.early_stopper = EarlyStopper(patience=kwargs['patience'])

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits; when ``masks`` is provided, also return the loss."""
        logits = self.unet(images)
        if masks is None:
            return logits
        return logits, self.loss_fn(logits, masks)

    def save(self) -> None:
        # Save the whole model, not only the state dict, so that it will work for different unets
        torch.save(self.unet, self.save_path)

    def print_summary(self, batch_size: int = 16) -> None:
        """Print a torchinfo summary of the wrapped unet for the given batch size."""
        print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
                      col_names=['input_size', 'output_size', 'num_params'],
                      row_settings=['var_names']))
|