i4ata commited on
Commit
4fcc913
·
1 Parent(s): 7e91fbc
app.py CHANGED
@@ -1,36 +1,44 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
-
5
  import torch
6
  import numpy as np
 
 
 
7
 
8
- from model import SegmentationModel
9
  from custom_unet import CustomUnet
10
- from unet import Unet
11
 
12
- from typing import Dict, Union, Tuple, List
13
 
14
  class GradioApp:
15
 
16
  def __init__(self) -> None:
17
 
18
- self.models: Dict[str, Union[str, SegmentationModel]] = {
19
- 'Custom': 'custom_unet',
20
- 'Pretrained': 'unet'
 
 
 
 
 
 
 
 
21
  }
22
 
23
  def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
24
 
25
- # Lazy loading of models
26
- if isinstance(self.models[model_name], str):
27
- model_class = CustomUnet if model_name == 'Custom' else Unet
28
- self.models[model_name] = model_class(self.models[model_name], from_file=True, device='cpu')
29
- self.models[model_name].eval()
30
-
31
- prediction = self.models[model_name].predict(img_file, option='mask')[0] * 1
32
- return img_file, [(prediction, 'person')]
33
-
34
  def launch(self):
35
 
36
  examples_list = [['examples/' + example] for example in os.listdir('examples')]
 
1
  import gradio as gr
2
  from PIL import Image
3
  import os
 
4
  import torch
5
  import numpy as np
6
+ import torchvision.transforms as transforms
7
+ from torchvision.transforms.functional import resize
8
+ from typing import Tuple, List
9
 
 
10
  from custom_unet import CustomUnet
11
+ from utils import val_transform, get_pretrained_unet
12
 
 
13
 
14
  class GradioApp:
15
 
16
  def __init__(self) -> None:
17
 
18
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+
20
+ custom = CustomUnet().to(device).eval()
21
+ custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=device))
22
+
23
+ pretrained = get_pretrained_unet().to(device).eval()
24
+ pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=device))
25
+
26
+ self.models = {
27
+ 'Custom': custom,
28
+ 'Pretrained': pretrained
29
  }
30
 
31
  def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
32
 
33
+ image = image=np.asarray(Image.open(img_file))
34
+ h,w = image.shape[:-1]
35
+ image = torch.from_numpy(val_transform(image=image)['image']).float().permute(2,0,1) / 255.
36
+ with torch.inference_mode():
37
+ prediction = self.models[model_name](image.to(self.device).unsqueeze(0))[0].sigmoid().round().cpu()
38
+ mask = resize(img=prediction, size=(h,w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()
39
+
40
+ return img_file, [(mask, 'person')]
41
+
42
  def launch(self):
43
 
44
  examples_list = [['examples/' + example] for example in os.listdir('examples')]
custom_unet.py CHANGED
@@ -7,46 +7,30 @@ Additional things: https://towardsdatascience.com/understanding-u-net-61276b10f3
7
 
8
  import torch
9
  import torch.nn as nn
10
- from torchinfo import summary
11
-
12
- from model import SegmentationModel
13
- from early_stopper import EarlyStopper
14
-
15
- from typing import Tuple, Union, Optional
16
-
17
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
-
19
- class DiceLoss(nn.Module):
20
-
21
- def forward(self, logits: torch.Tensor, mask_true: torch.Tensor):
22
- logits = torch.sigmoid(logits) > .5
23
- intersection = (logits * mask_true).sum()
24
- union = logits.sum() + mask_true.sum()
25
- return 2 * intersection / union
26
 
27
  class DoubleConv(nn.Module):
28
 
29
  def __init__(self, in_channels: int, out_channels: int) -> None:
30
 
31
  super().__init__()
32
- self.relu = nn.ReLU()
33
  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
34
  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
35
 
36
  def forward(self, x: torch.Tensor):
37
- return self.relu(self.conv2(self.relu(self.conv1(x))))
38
 
39
  class Up(nn.Module):
40
 
41
- def __init__(self, in_channels, out_channels) -> None:
42
  super().__init__()
43
  self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
44
  self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)
45
 
46
- def forward(self, x_left, x_right):
47
  return self.conv(torch.cat((x_left, self.upconv(x_right)), dim=1))
48
 
49
- class UnetModel(nn.Module):
50
 
51
  def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
52
 
@@ -65,8 +49,6 @@ class UnetModel(nn.Module):
65
  start_channels //= 2
66
 
67
  self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
68
-
69
- self.pool = nn.MaxPool2d(2, 2)
70
 
71
  def forward(self, x: torch.Tensor) -> torch.Tensor:
72
 
@@ -74,62 +56,10 @@ class UnetModel(nn.Module):
74
  xs = [x]
75
 
76
  for encoding_layer in self.encoder_layers:
77
- x = encoding_layer(self.pool(x))
78
  xs.append(x)
79
 
80
  for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1])):
81
  x = decoding_layer(x_left, x)
82
 
83
  return self.output_conv(x)
84
-
85
- class CustomUnet(SegmentationModel):
86
-
87
- def __init__(self,
88
- name: str = 'default_name',
89
- from_file: bool = True,
90
- image_size: Tuple[int, int] = (320, 320),
91
- in_channels: int = 3,
92
- start_channels: int = 16,
93
- encoder_depth: int = 5,
94
- device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
95
-
96
- super().__init__()
97
-
98
- assert image_size[0] % (2**encoder_depth) == 0
99
- assert image_size[1] % (2**encoder_depth) == 0
100
-
101
- self.name = name
102
- self.image_size = image_size
103
- self.in_channels = in_channels
104
- self.device = device
105
-
106
- self.save_path = f'models/{name}.pth'
107
-
108
- if from_file:
109
- self.unet = torch.load(self.save_path, map_location=device)
110
- else:
111
- self.unet = UnetModel(in_channels=in_channels, depth=encoder_depth, start_channels=start_channels).to(device)
112
-
113
- self.bce_loss = nn.BCEWithLogitsLoss()
114
- self.dice_loss = DiceLoss()
115
- self.loss_fn = lambda logits, masks: self.bce_loss(logits, masks) + self.dice_loss(logits, masks)
116
-
117
- def configure_optimizers(self, **kwargs):
118
- self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
119
- self.early_stopper = EarlyStopper(patience=kwargs['patience'])
120
-
121
- def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
122
- logits = self.unet(images)
123
- if masks is None:
124
- return logits
125
- return logits, self.loss_fn(logits, masks)
126
-
127
- def save(self) -> None:
128
- # Save the whole model, not only the state dict, so that it will work for different unets
129
- torch.save(self.unet, self.save_path)
130
-
131
- def print_summary(self, batch_size: int = 16) -> None:
132
-
133
- print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
134
- col_names=['input_size', 'output_size', 'num_params'],
135
- row_settings=['var_names']))
 
7
 
8
  import torch
9
  import torch.nn as nn
10
+ import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
class DoubleConv(nn.Module):
    """Two consecutive 3x3 same-padding convolutions, each followed by ReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

    def forward(self, x: torch.Tensor):
        # Apply the two conv+ReLU stages one after the other.
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        return out
22
 
23
class Up(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, run DoubleConv."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # 2x2 stride-2 transposed conv doubles the spatial resolution.
        self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        upsampled = self.upconv(x_right)
        # Channel-wise concat of the encoder skip (x_left) with the upsampled path.
        merged = torch.cat((x_left, upsampled), dim=1)
        return self.conv(merged)
32
 
33
+ class CustomUnet(nn.Module):
34
 
35
  def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
36
 
 
49
  start_channels //= 2
50
 
51
  self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
 
 
52
 
53
  def forward(self, x: torch.Tensor) -> torch.Tensor:
54
 
 
56
  xs = [x]
57
 
58
  for encoding_layer in self.encoder_layers:
59
+ x = encoding_layer(F.max_pool2d(x, 2))
60
  xs.append(x)
61
 
62
  for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1])):
63
  x = decoding_layer(x_left, x)
64
 
65
  return self.output_conv(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
early_stopper.py DELETED
@@ -1,23 +0,0 @@
1
- """This module contains a class that implements early stopping regularization technique"""
2
-
3
- class EarlyStopper:
4
-
5
- def __init__(self, patience: int = 2):
6
-
7
- self.patience = patience
8
- self.best_loss = float('inf')
9
- self.counter = 0
10
- self.save_model = False
11
-
12
- def check(self, validation_loss: float) -> bool:
13
-
14
- self.save_model = False
15
- if validation_loss > self.best_loss:
16
- self.counter += 1
17
- if self.counter == self.patience:
18
- return True
19
- else:
20
- self.best_loss = validation_loss
21
- self.counter = 0
22
- self.save_model = True
23
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py DELETED
@@ -1,130 +0,0 @@
1
- """This module contains the base class for segmentation models"""
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torch.utils.data import DataLoader
6
- from torchvision.utils import draw_segmentation_masks
7
- from torchvision.transforms.functional import resize
8
- from torch.utils.tensorboard.writer import SummaryWriter
9
-
10
- import numpy as np
11
- import cv2 as cv
12
- import albumentations as A
13
-
14
- from typing import Optional, Union, Tuple, Literal
15
-
16
- from early_stopper import EarlyStopper
17
-
18
- class SegmentationModel(nn.Module):
19
-
20
- name: str = "base name"
21
- device: Literal['cpu', 'cuda'] = None
22
-
23
- optimizer: torch.optim.Optimizer = None
24
- early_stopper: EarlyStopper = None
25
- lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
26
- save_path: str = None
27
- image_size: Tuple[int, int] = None
28
-
29
- def configure_optimizers(self, **kwargs) -> None:
30
- raise NotImplementedError()
31
-
32
- def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
33
-
34
- raise NotImplementedError()
35
-
36
- def _train_step(self, data_loader: DataLoader) -> float:
37
-
38
- self.train()
39
- total_loss = 0.
40
- for images, masks in data_loader:
41
- images, masks = images.to(self.device), masks.to(self.device)
42
-
43
- self.optimizer.zero_grad()
44
- logits, loss = self(images, masks)
45
- loss.backward()
46
- self.optimizer.step()
47
-
48
- total_loss += loss.item()
49
-
50
- return total_loss / len(data_loader)
51
-
52
- def _test_step(self, data_loader: DataLoader) -> float:
53
-
54
- self.eval()
55
- total_loss = 0.
56
- with torch.inference_mode():
57
- for images, masks in data_loader:
58
- images, masks = images.to(self.device), masks.to(self.device)
59
- logits, loss = self(images, masks)
60
- total_loss += loss.item()
61
- return total_loss / len(data_loader)
62
-
63
- def train_model(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int, log_dir: str) -> None:
64
-
65
- writer = SummaryWriter(log_dir=f'{log_dir}/{self.name}')
66
-
67
- for i in range(epochs):
68
- train_loss = self._train_step(train_loader)
69
- test_loss = self._test_step(test_loader)
70
-
71
- if self.early_stopper is not None:
72
- if self.early_stopper.check(test_loss):
73
- print(f'Model stopped early due to risk of overfitting')
74
- break
75
-
76
- if self.early_stopper.save_model:
77
- self.save()
78
- print('saved model')
79
-
80
- if self.lr_scheduler is not None:
81
- self.lr_scheduler.step()
82
-
83
- print(f'{i}: Train loss: {train_loss :.2} | Test loss: {test_loss :.2}')
84
-
85
- writer.add_scalars(main_tag='Loss over time',
86
- tag_scalar_dict={'train loss': train_loss, 'test loss': test_loss},
87
- global_step=i)
88
-
89
- else:
90
- if self.early_stopper is not None:
91
- print('Model did not converge. Possibility of underfitting')
92
- self.save()
93
- writer.close()
94
-
95
- def save(self) -> None:
96
- raise NotImplementedError()
97
-
98
- def predict(self,
99
- test_image_path: str,
100
- option: Literal['mask', 'image_with_mask', 'mask_and_image_with_mask'] = 'image_with_mask'
101
- ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
102
-
103
- self.eval()
104
- input_resizer = A.Resize(*self.image_size)
105
-
106
- original_image = cv.cvtColor(cv.imread(test_image_path), cv.COLOR_BGR2RGB)
107
- original_image_tensor = torch.from_numpy(original_image).permute(2,0,1).type(torch.uint8)
108
- resized_image_tensor = (torch.from_numpy(input_resizer(image=original_image)['image']).float() / 255.).permute(2,0,1)
109
-
110
- with torch.inference_mode():
111
- logits = self(resized_image_tensor.unsqueeze(0).to(self.device)).squeeze(0).cpu().detach()
112
- probs = torch.sigmoid(logits)
113
- resized_mask_tensor = probs > .5
114
-
115
- original_mask_tensor = resize(resized_mask_tensor, size=original_image.shape[:-1], antialias=True)
116
-
117
- image_with_mask = draw_segmentation_masks(image=original_image_tensor,
118
- masks=original_mask_tensor,
119
- alpha=.5,
120
- colors='white')
121
-
122
- if option == 'mask':
123
- return original_mask_tensor.numpy()
124
- if option == 'image_with_mask':
125
- return image_with_mask.permute(1,2,0).numpy()
126
- if option == 'mask_and_image_with_mask':
127
- return original_mask_tensor.numpy(), image_with_mask.permute(1,2,0).numpy()
128
-
129
- def print_summary(self) -> None:
130
- raise NotImplementedError()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/{custom_unet.pth → custom_unet.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:eeca616e3026a77a2125e4c880f5335e1efa4a2c53b1ad4dad0082e227e49b85
3
- size 7812958
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a227b749031ac42b97c9833bd18e8c37b5104bb94546f2063310a73a9b912fe5
3
+ size 1941386
models/{unet.pth → pretrained_unet.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d26354f766301bc980c66f3984599491a4c3dc35706dc3e31a95f115e30a74c6
3
- size 25378610
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b3fe6a191c10e927901529593049ba76ccbab708ec62a8d25fa2f8b46cb4ac2
3
+ size 25339050
requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
  torch
2
- torchinfo
3
  segmentation-models-pytorch
4
  albumentations
5
  opencv-python
6
  gradio
7
  numpy
8
  matplotlib
9
- tensorboard
 
1
  torch
 
2
  segmentation-models-pytorch
3
  albumentations
4
  opencv-python
5
  gradio
6
  numpy
7
  matplotlib
8
+ tensorboard
unet.py DELETED
@@ -1,67 +0,0 @@
1
- """This module defines a Unet architecture"""
2
-
3
- import torch.nn as nn
4
- import torch
5
- from torchinfo import summary
6
- import segmentation_models_pytorch as smp
7
- from early_stopper import EarlyStopper
8
-
9
- from model import SegmentationModel
10
-
11
- from typing import Optional, Union, Tuple
12
-
13
- class Unet(SegmentationModel):
14
-
15
- def __init__(self,
16
- name: str = 'default_name',
17
- from_file: bool = True,
18
- image_size: Tuple[int, int] = (320, 320),
19
- encoder_name: str = 'timm-efficientnet-b0',
20
- pretrained: bool = True,
21
- in_channels: int = 3,
22
- encoder_depth: int = 5,
23
- device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
24
-
25
- super().__init__()
26
-
27
- self.name = name
28
- self.image_size = image_size
29
- self.in_channels = in_channels
30
- self.device = device
31
-
32
- self.save_path = f'models/{name}.pth'
33
-
34
- if from_file:
35
- self.unet = torch.load(self.save_path, map_location=device)
36
- else:
37
- self.unet = smp.Unet(
38
- encoder_name=encoder_name,
39
- encoder_weights='imagenet' if pretrained else None,
40
- in_channels=in_channels,
41
- encoder_depth=encoder_depth,
42
- classes=1,
43
- activation=None
44
- ).to(device)
45
-
46
- bce_loss_fn = nn.BCEWithLogitsLoss()
47
- dice_loss_fn = smp.losses.DiceLoss(mode='binary')
48
- self.loss_fn = lambda logits, masks: bce_loss_fn(logits, masks) + dice_loss_fn(logits, masks)
49
-
50
- def configure_optimizers(self, **kwargs):
51
- self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
52
- self.early_stopper = EarlyStopper(patience=kwargs['patience'])
53
-
54
- def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
55
- logits = self.unet(images)
56
- if masks is None:
57
- return logits
58
- return logits, self.loss_fn(logits, masks)
59
-
60
- def save(self) -> None:
61
- # Save the whole model, not only the state dict, so that it will work for different unets
62
- torch.save(self.unet, self.save_path)
63
-
64
- def print_summary(self, batch_size: int = 16) -> None:
65
- print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
66
- col_names=['input_size', 'output_size', 'num_params'],
67
- row_settings=['var_names']))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ from segmentation_models_pytorch import Unet
3
+
4
# Validation/inference preprocessing: resize every input image to the
# models' fixed 320x320 input size.
# NOTE(review): is_check_shapes=False presumably because no mask is passed
# alongside the image at inference time — confirm against callers.
val_transform = A.Compose(
    transforms=[
        A.Resize(320, 320)
    ],
    is_check_shapes=False
)
10
+
11
def get_pretrained_unet() -> Unet:
    """Build the single-class smp U-Net with an ImageNet-pretrained
    timm-efficientnet-b0 encoder (logit output, no activation)."""
    return Unet(
        encoder_name='timm-efficientnet-b0',
        encoder_weights='imagenet',
        in_channels=3,
        encoder_depth=5,
        classes=1,
        activation=None
    )