i4ata commited on
Commit
671b160
·
1 Parent(s): 8b5523a

demo to spaces

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 i4ata
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,19 @@
1
- ---
2
- title: CustomUnetSegmentation
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.19.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
+ # UnetSegmentation
2
+
3
+ My own implementation of the U-net architecture compared to a pretrained model from PyTorch Segmentation Models.
4
+
5
+ To train:
6
+
7
+ ```python train.py```
8
+
9
+ To visualize training:
10
+
11
+ ```tensorboard --logdir runs```
12
+
13
+ To visually compare models on some examples:
14
+
15
+ ```python compare_models.py```
16
+
17
+ To launch a Gradio application:
18
+
19
+ ```python3 app.py```
__pycache__/custom_unet.cpython-310.pyc ADDED
Binary file (5.68 kB). View file
 
__pycache__/early_stopper.cpython-310.pyc ADDED
Binary file (956 Bytes). View file
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (4.59 kB). View file
 
__pycache__/unet.cpython-310.pyc ADDED
Binary file (2.81 kB). View file
 
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio demo comparing a custom U-net with a pretrained one on person segmentation."""

import os

import gradio as gr
import numpy as np
import torch

from model import SegmentationModel
from custom_unet import CustomUnet
from unet import Unet

from typing import Dict, Union, Tuple, List


class GradioApp:
    """Wraps the two segmentation models in a Gradio interface."""

    def __init__(self) -> None:
        # Maps display name -> checkpoint name (str) until first use, after
        # which the value is replaced by the loaded model (lazy loading).
        self.models: Dict[str, Union[str, SegmentationModel]] = {
            'Custom': 'custom_unet',
            'Pretrained': 'unet'
        }

    def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
        """Segment the image at ``img_file`` with the selected model.

        Returns the original image path and one (mask, label) pair in the
        format expected by ``gr.AnnotatedImage``.
        """
        # Lazy loading of models: swap the checkpoint name for the model
        # object on first use so app start-up stays fast.
        if isinstance(self.models[model_name], str):
            model_class = CustomUnet if model_name == 'Custom' else Unet
            self.models[model_name] = model_class(self.models[model_name], from_file=True, device='cpu')
            self.models[model_name].eval()
        # ``* 1`` converts the boolean mask to an integer array for Gradio.
        prediction = self.models[model_name].predict(img_file, option='mask')[0] * 1
        return img_file, [(prediction, 'person')]

    def launch(self) -> None:
        """Build and launch the Gradio interface with the bundled example images."""
        examples_list = [[os.path.join('examples', example)] for example in os.listdir('examples')]

        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to segment'),
                # Derive the choices from the models mapping so the radio
                # buttons and the available models can never drift apart.
                gr.Radio(choices=list(self.models), label='Available models')
            ],
            outputs=gr.AnnotatedImage(label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
        )
        demo.launch()


if __name__ == '__main__':
    app = GradioApp()
    app.launch()
custom_unet.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This python module implements the Unet architecture as defined in https://arxiv.org/pdf/1505.04597.
2
+ Only, I use padded convolutions. That way, there is no need for center cropping and the output mask
3
+ is the same shape as the input image.
4
+
5
+ Additional things: https://towardsdatascience.com/understanding-u-net-61276b10f360
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torchinfo import summary
11
+
12
+ from model import SegmentationModel
13
+ from early_stopper import EarlyStopper
14
+
15
+ from typing import Tuple, Union, Optional
16
+
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
class DiceLoss(nn.Module):
    """Soft (differentiable) Dice loss for binary segmentation logits.

    The previous implementation hard-thresholded the probabilities
    (``sigmoid(logits) > .5``), which has no gradient, and returned the
    Dice *coefficient* (higher is better) — so adding it to BCE rewarded
    worse predictions. This version keeps the soft probabilities and
    returns ``1 - dice`` so that minimizing the loss maximizes overlap.
    """

    def forward(self, logits: torch.Tensor, mask_true: torch.Tensor) -> torch.Tensor:
        # Soft probabilities keep the loss differentiable w.r.t. the logits.
        probs = torch.sigmoid(logits)
        intersection = (probs * mask_true).sum()
        union = probs.sum() + mask_true.sum()
        # Smoothing term avoids division by zero when both masks are empty.
        smooth = 1e-6
        dice = (2 * intersection + smooth) / (union + smooth)
        return 1 - dice
26
+
27
class DoubleConv(nn.Module):
    """Two 3x3 'same'-padded convolutions, each followed by a ReLU.

    Spatial dimensions are preserved; only the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # conv -> relu -> conv -> relu
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        return out
38
+
39
class Up(nn.Module):
    """One decoder stage: 2x transposed-conv upsampling, skip-connection
    concatenation, then a DoubleConv block.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x_left, x_right):
        # Upsample the deeper feature map, then fuse it with the encoder
        # skip connection along the channel dimension.
        upsampled = self.upconv(x_right)
        fused = torch.cat((x_left, upsampled), dim=1)
        return self.conv(fused)
48
+
49
class UnetModel(nn.Module):
    """Plain U-net: an encoder/decoder with skip connections and a 1x1
    output convolution producing a single-channel logit map.

    Each of the ``depth`` pooling steps halves the spatial size, so
    inputs should be divisible by ``2 ** depth``.
    """

    def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
        super().__init__()

        self.input_conv = DoubleConv(in_channels, start_channels)

        # Encoder: the channel count doubles at every level.
        self.encoder_layers = nn.ModuleList()
        channels = start_channels
        for _ in range(depth):
            self.encoder_layers.append(DoubleConv(channels, channels * 2))
            channels *= 2

        # Decoder: the channel count halves back at every level.
        self.decoder_layers = nn.ModuleList()
        for _ in range(depth):
            self.decoder_layers.append(Up(channels, channels // 2))
            channels //= 2

        # 1x1 convolution mapping the remaining features to one logit channel.
        self.output_conv = nn.Conv2d(channels, 1, kernel_size=1)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_conv(x)
        skips = [x]

        # Downsample, keeping every level's output for the skip connections.
        for encode in self.encoder_layers:
            x = encode(self.pool(x))
            skips.append(x)

        # Upsample, fusing each stage with the matching encoder output.
        # The deepest feature map is the running x itself, hence [:-1].
        for decode, skip in zip(self.decoder_layers, reversed(skips[:-1])):
            x = decode(skip, x)

        return self.output_conv(x)
84
+
85
class CustomUnet(SegmentationModel):
    """Hand-written U-net wrapped in the project's SegmentationModel interface."""

    def __init__(self,
                 name: str = 'default_name',
                 from_file: bool = True,
                 image_size: Tuple[int, int] = (320, 320),
                 in_channels: int = 3,
                 start_channels: int = 16,
                 encoder_depth: int = 5,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
        """Load the model from ``models/<name>.pth`` when ``from_file`` is
        True; otherwise build a fresh UnetModel (in that case
        ``in_channels``/``start_channels``/``encoder_depth`` take effect).
        """
        super().__init__()

        # Each encoder level halves the spatial size, so both dimensions
        # must be divisible by 2 ** encoder_depth.
        assert image_size[0] % (2**encoder_depth) == 0
        assert image_size[1] % (2**encoder_depth) == 0

        self.name = name
        self.image_size = image_size
        self.in_channels = in_channels
        self.device = device

        self.save_path = f'models/{name}.pth'

        if from_file:
            # The checkpoint stores the whole pickled model, not a state dict
            # (see save() below), so torch.load returns a ready module.
            self.unet = torch.load(self.save_path, map_location=device)
        else:
            self.unet = UnetModel(in_channels=in_channels, depth=encoder_depth, start_channels=start_channels).to(device)

        # Combined training objective: BCE on the logits plus Dice loss.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
        self.loss_fn = lambda logits, masks: self.bce_loss(logits, masks) + self.dice_loss(logits, masks)

    def configure_optimizers(self, **kwargs):
        """Create the Adam optimizer and the early stopper.

        Requires 'lr' and 'patience' in kwargs; raises KeyError otherwise.
        """
        self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
        self.early_stopper = EarlyStopper(patience=kwargs['patience'])

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits; when ground-truth masks are given, also return the loss."""
        logits = self.unet(images)
        if masks is None:
            return logits
        return logits, self.loss_fn(logits, masks)

    def save(self) -> None:
        # Save the whole model, not only the state dict, so that it will work for different unets
        torch.save(self.unet, self.save_path)

    def print_summary(self, batch_size: int = 16) -> None:
        """Print a torchinfo layer-by-layer summary for a dummy batch."""
        print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
                      col_names=['input_size', 'output_size', 'num_params'],
                      row_settings=['var_names']))
early_stopper.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains a class that implements early stopping regularization technique"""
2
+
3
class EarlyStopper:
    """Tracks the validation loss and signals when training should stop."""

    def __init__(self, patience: int = 2):
        # Number of consecutive non-improving epochs tolerated before stopping.
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0
        # Set by check(): whether the last epoch produced a new best model.
        self.save_model = False

    def check(self, validation_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop.

        A loss less than or equal to the best seen so far counts as an
        improvement: it resets the patience counter and flags the model
        for saving.
        """
        self.save_model = False
        if validation_loss <= self.best_loss:
            # Improvement (or tie): remember it and reset the patience counter.
            self.best_loss = validation_loss
            self.counter = 0
            self.save_model = True
            return False
        # Degradation: stop once `patience` consecutive bad epochs are seen.
        self.counter += 1
        return self.counter == self.patience
examples/example_1.jpg ADDED
examples/example_2.jpg ADDED
examples/example_3.jpg ADDED
model.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains the base class for segmentation models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader
6
+ from torchvision.utils import draw_segmentation_masks
7
+ from torchvision.transforms.functional import resize
8
+ from torch.utils.tensorboard.writer import SummaryWriter
9
+
10
+ import numpy as np
11
+ import cv2 as cv
12
+ import albumentations as A
13
+
14
+ from typing import Optional, Union, Tuple, Literal
15
+
16
+ from early_stopper import EarlyStopper
17
+
18
class SegmentationModel(nn.Module):
    """Base class for binary segmentation models.

    Provides the training/evaluation loop, TensorBoard logging and
    single-image prediction; subclasses implement forward(),
    configure_optimizers(), save() and print_summary().
    """

    # Human-readable identifier; subclasses override it (used as the log dir name).
    name: str = "base name"
    # Device that batches are moved to in the train/test steps.
    device: Literal['cpu', 'cuda'] = None

    # Populated by subclasses in configure_optimizers().
    optimizer: torch.optim.Optimizer = None
    early_stopper: EarlyStopper = None
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
    # Checkpoint location; set by subclasses.
    save_path: str = None
    # (height, width) that inputs are resized to before inference.
    image_size: Tuple[int, int] = None

    def configure_optimizers(self, **kwargs) -> None:
        """Create the optimizer (and optionally early stopper / scheduler). Subclasses must implement."""
        raise NotImplementedError()

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits, or (logits, loss) when ground-truth masks are given. Subclasses must implement."""
        raise NotImplementedError()

    def _train_step(self, data_loader: DataLoader) -> float:
        """Run one training epoch and return the mean batch loss."""
        self.train()
        total_loss = 0.
        for images, masks in data_loader:
            images, masks = images.to(self.device), masks.to(self.device)

            self.optimizer.zero_grad()
            logits, loss = self(images, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(data_loader)

    def _test_step(self, data_loader: DataLoader) -> float:
        """Evaluate on a loader (no gradients) and return the mean batch loss."""
        self.eval()
        total_loss = 0.
        with torch.inference_mode():
            for images, masks in data_loader:
                images, masks = images.to(self.device), masks.to(self.device)
                logits, loss = self(images, masks)
                total_loss += loss.item()
        return total_loss / len(data_loader)

    def train_model(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int, log_dir: str) -> None:
        """Train for up to ``epochs`` epochs with TensorBoard logging.

        When an early stopper is configured, training stops as soon as it
        triggers and the best-so-far checkpoint has already been written;
        otherwise the model is saved once after the final epoch.
        """
        writer = SummaryWriter(log_dir=f'{log_dir}/{self.name}')

        for i in range(epochs):
            train_loss = self._train_step(train_loader)
            test_loss = self._test_step(test_loader)

            if self.early_stopper is not None:
                if self.early_stopper.check(test_loss):
                    print(f'Model stopped early due to risk of overfitting')
                    break

                # save_model is True only when this epoch improved on the
                # best validation loss seen so far.
                if self.early_stopper.save_model:
                    self.save()
                    print('saved model')

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            print(f'{i}: Train loss: {train_loss :.2} | Test loss: {test_loss :.2}')

            writer.add_scalars(main_tag='Loss over time',
                               tag_scalar_dict={'train loss': train_loss, 'test loss': test_loss},
                               global_step=i)

        else:
            # for/else: runs only when the loop finished without break,
            # i.e. early stopping never triggered.
            if self.early_stopper is not None:
                print('Model did not converge. Possibility of underfitting')
            self.save()
        writer.close()

    def save(self) -> None:
        """Persist the model to self.save_path. Subclasses must implement."""
        raise NotImplementedError()

    def predict(self,
                test_image_path: str,
                option: Literal['mask', 'image_with_mask', 'mask_and_image_with_mask'] = 'image_with_mask'
                ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Segment a single image from disk.

        Depending on ``option`` returns the boolean mask (at the original
        image resolution), the original image with the mask drawn on top,
        or both. NOTE(review): an unrecognized ``option`` silently returns
        None — consider raising ValueError instead.
        """
        self.eval()
        input_resizer = A.Resize(*self.image_size)

        # cv2 loads BGR; convert to RGB, which is what the model was trained on.
        original_image = cv.cvtColor(cv.imread(test_image_path), cv.COLOR_BGR2RGB)
        original_image_tensor = torch.from_numpy(original_image).permute(2,0,1).type(torch.uint8)
        # Resize to the model's expected input size and scale to [0, 1].
        resized_image_tensor = (torch.from_numpy(input_resizer(image=original_image)['image']).float() / 255.).permute(2,0,1)

        with torch.inference_mode():
            logits = self(resized_image_tensor.unsqueeze(0).to(self.device)).squeeze(0).cpu().detach()
            probs = torch.sigmoid(logits)
            resized_mask_tensor = probs > .5

        # Upsample the predicted mask back to the original image resolution.
        original_mask_tensor = resize(resized_mask_tensor, size=original_image.shape[:-1], antialias=True)

        image_with_mask = draw_segmentation_masks(image=original_image_tensor,
                                                  masks=original_mask_tensor,
                                                  alpha=.5,
                                                  colors='white')

        if option == 'mask':
            return original_mask_tensor.numpy()
        if option == 'image_with_mask':
            return image_with_mask.permute(1,2,0).numpy()
        if option == 'mask_and_image_with_mask':
            return original_mask_tensor.numpy(), image_with_mask.permute(1,2,0).numpy()

    def print_summary(self) -> None:
        """Print a model summary. Subclasses must implement."""
        raise NotImplementedError()
models/custom_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeca616e3026a77a2125e4c880f5335e1efa4a2c53b1ad4dad0082e227e49b85
3
+ size 7812958
models/unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d26354f766301bc980c66f3984599491a4c3dc35706dc3e31a95f115e30a74c6
3
+ size 25378610
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchinfo
3
+ segmentation-models-pytorch
4
+ albumentations
5
+ opencv-python
6
+ gradio
7
+ numpy
8
+ matplotlib
9
+ tensorboard
unet.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module defines a Unet architecture"""
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+ from torchinfo import summary
6
+ import segmentation_models_pytorch as smp
7
+ from early_stopper import EarlyStopper
8
+
9
+ from model import SegmentationModel
10
+
11
+ from typing import Optional, Union, Tuple
12
+
13
class Unet(SegmentationModel):
    """Pretrained U-net from segmentation-models-pytorch wrapped in the
    project's SegmentationModel interface. Produces a single-channel
    logit mask (binary segmentation).
    """

    def __init__(self,
                 name: str = 'default_name',
                 from_file: bool = True,
                 image_size: Tuple[int, int] = (320, 320),
                 encoder_name: str = 'timm-efficientnet-b0',
                 pretrained: bool = True,
                 in_channels: int = 3,
                 encoder_depth: int = 5,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
        """Load the model from ``models/<name>.pth`` when ``from_file`` is
        True; otherwise build a fresh smp.Unet (only then do
        ``encoder_name``/``pretrained``/``in_channels``/``encoder_depth``
        take effect).
        """
        super().__init__()

        self.name = name
        self.image_size = image_size
        self.in_channels = in_channels
        self.device = device

        self.save_path = f'models/{name}.pth'

        if from_file:
            # The checkpoint stores the whole pickled model, not a state dict
            # (see save() below), so torch.load returns a ready module.
            self.unet = torch.load(self.save_path, map_location=device)
        else:
            self.unet = smp.Unet(
                encoder_name=encoder_name,
                encoder_weights='imagenet' if pretrained else None,
                in_channels=in_channels,
                encoder_depth=encoder_depth,
                classes=1,
                activation=None
            ).to(device)

        # Combined training objective: BCE on the logits plus binary Dice loss.
        bce_loss_fn = nn.BCEWithLogitsLoss()
        dice_loss_fn = smp.losses.DiceLoss(mode='binary')
        self.loss_fn = lambda logits, masks: bce_loss_fn(logits, masks) + dice_loss_fn(logits, masks)

    def configure_optimizers(self, **kwargs):
        """Create the Adam optimizer and the early stopper.

        Requires 'lr' and 'patience' in kwargs; raises KeyError otherwise.
        """
        self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
        self.early_stopper = EarlyStopper(patience=kwargs['patience'])

    def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return logits; when ground-truth masks are given, also return the loss."""
        logits = self.unet(images)
        if masks is None:
            return logits
        return logits, self.loss_fn(logits, masks)

    def save(self) -> None:
        # Save the whole model, not only the state dict, so that it will work for different unets
        torch.save(self.unet, self.save_path)

    def print_summary(self, batch_size: int = 16) -> None:
        """Print a torchinfo layer-by-layer summary for a dummy batch."""
        print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
                      col_names=['input_size', 'output_size', 'num_params'],
                      row_settings=['var_names']))