i4ata committed on
Commit
ed1f711
·
1 Parent(s): fb9b166
__pycache__/custom_unet.cpython-310.pyc DELETED
Binary file (5.68 kB)
 
__pycache__/early_stopper.cpython-310.pyc DELETED
Binary file (956 Bytes)
 
__pycache__/model.cpython-310.pyc DELETED
Binary file (4.59 kB)
 
__pycache__/unet.cpython-310.pyc DELETED
Binary file (2.81 kB)
 
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import gradio as gr
2
  from PIL import Image
3
- import os
4
  import torch
5
  import numpy as np
6
  import torchvision.transforms as transforms
7
  from torchvision.transforms.functional import resize
 
 
8
  from typing import Tuple, List
 
 
9
 
10
  from custom_unet import CustomUnet
11
- from utils import val_transform, get_pretrained_unet
12
-
13
 
14
  class GradioApp:
15
 
@@ -17,45 +18,39 @@ class GradioApp:
17
 
18
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
- custom = CustomUnet().to(self.device).eval()
21
- custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=self.device))
22
 
23
- pretrained = get_pretrained_unet().to(self.device).eval()
24
- pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=self.device))
25
 
26
- self.models = {
27
- 'Custom': custom,
28
- 'Pretrained': pretrained
29
- }
30
 
31
  def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
32
 
33
- image = image=np.asarray(Image.open(img_file))
34
  h,w = image.shape[:-1]
35
- image = torch.from_numpy(val_transform(image=image)['image']).float().permute(2,0,1) / 255.
36
  with torch.inference_mode():
37
  prediction = self.models[model_name](image.to(self.device).unsqueeze(0))[0].sigmoid().round().cpu()
38
  mask = resize(img=prediction, size=(h,w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()
39
-
40
  return img_file, [(mask, 'person')]
41
 
42
  def launch(self):
43
-
44
- examples_list = [['examples/' + example] for example in os.listdir('examples')]
45
-
46
  demo = gr.Interface(
47
  fn=self.predict,
48
  inputs=[
49
  gr.Image(type='filepath', label='Input image to segment'),
50
- gr.Radio(choices=('Custom', 'Pretrained'), label='Available models')
51
  ],
52
  outputs=gr.AnnotatedImage(label='Model predictions'),
53
- examples=examples_list,
54
  cache_examples=False,
55
  title='Person Segmentation',
56
  description=f'This model performs segmentation on people in images. A Unet neural network architecture is used. \
57
  The dataset can be found [here](https://github.com/VikramShenoy97/Human-Segmentation-Dataset) \
58
- and the source code is on [GitHub](https://github.com/i4ata/UnetSegmentation).',
59
  )
60
  demo.launch()
61
 
 
1
  import gradio as gr
2
  from PIL import Image
 
3
  import torch
4
  import numpy as np
5
  import torchvision.transforms as transforms
6
  from torchvision.transforms.functional import resize
7
+ import albumentations as A
8
+ from segmentation_models_pytorch import Unet
9
  from typing import Tuple, List
10
+ import os
11
+ from glob import glob
12
 
13
  from custom_unet import CustomUnet
 
 
14
 
15
  class GradioApp:
16
 
 
18
 
19
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
 
21
+ custom = CustomUnet(in_channels=3, depth=3, start_channels=16).to(self.device).eval()
22
+ custom.load_state_dict(torch.load(os.path.join('models', 'custom_unet.pt'), map_location=self.device, weights_only=False))
23
 
24
+ pretrained = Unet(encoder_name='timm-efficientnet-b0', in_channels=3, encoder_depth=5, classes=1).to(self.device).eval()
25
+ pretrained.load_state_dict(torch.load(os.path.join('models', 'pretrained_unet.pt'), map_location=self.device, weights_only=False))
26
 
27
+ self.models = {'Custom': custom, 'Pretrained': pretrained}
28
+ self.transform = A.Compose(transforms=[A.Resize(320, 320)])
 
 
29
 
30
  def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
31
 
32
+ image = np.asarray(Image.open(img_file))
33
  h,w = image.shape[:-1]
34
+ image = torch.from_numpy(self.transform(image=image)['image']).float().permute(2,0,1) / 255.
35
  with torch.inference_mode():
36
  prediction = self.models[model_name](image.to(self.device).unsqueeze(0))[0].sigmoid().round().cpu()
37
  mask = resize(img=prediction, size=(h,w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()
 
38
  return img_file, [(mask, 'person')]
39
 
40
  def launch(self):
 
 
 
41
  demo = gr.Interface(
42
  fn=self.predict,
43
  inputs=[
44
  gr.Image(type='filepath', label='Input image to segment'),
45
+ gr.Radio(choices=('Custom', 'Pretrained'), label='Available models', value='Custom')
46
  ],
47
  outputs=gr.AnnotatedImage(label='Model predictions'),
48
+ examples=[[example_path] for example_path in glob('examples/*.jpg')],
49
  cache_examples=False,
50
  title='Person Segmentation',
51
  description=f'This model performs segmentation on people in images. A Unet neural network architecture is used. \
52
  The dataset can be found [here](https://github.com/VikramShenoy97/Human-Segmentation-Dataset) \
53
+ and the source code is on [GitHub](https://github.com/i4ata/UnetSegmentation).'
54
  )
55
  demo.launch()
56
 
custom_unet.py CHANGED
@@ -1,18 +1,10 @@
1
- """This python module impements the Unet architecture as defined in https://arxiv.org/pdf/1505.04597.
2
- Only, I use padded convolutions. That way, there is no need for center cropping and the output mask
3
- is the same shape as the input image.
4
-
5
- Additional things: https://towardsdatascience.com/understanding-u-net-61276b10f360
6
- """
7
-
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
 
12
  class DoubleConv(nn.Module):
13
-
14
  def __init__(self, in_channels: int, out_channels: int) -> None:
15
-
16
  super().__init__()
17
  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
18
  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
@@ -21,45 +13,36 @@ class DoubleConv(nn.Module):
21
  return F.relu(self.conv2(F.relu(self.conv1(x))))
22
 
23
  class Up(nn.Module):
24
-
25
  def __init__(self, in_channels: int, out_channels: int) -> None:
26
  super().__init__()
27
  self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
28
  self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)
29
-
30
  def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
31
  return self.conv(torch.cat((x_left, self.upconv(x_right)), dim=1))
32
 
33
  class CustomUnet(nn.Module):
34
 
35
- def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
36
-
37
  super().__init__()
38
-
39
  self.input_conv = DoubleConv(in_channels, start_channels)
40
-
41
  self.encoder_layers = nn.ModuleList()
42
  for i in range(depth):
43
  self.encoder_layers.append(DoubleConv(start_channels, start_channels * 2))
44
  start_channels *= 2
45
-
46
  self.decoder_layers = nn.ModuleList()
47
  for i in range(depth):
48
  self.decoder_layers.append(Up(start_channels, start_channels // 2))
49
  start_channels //= 2
50
-
51
  self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
52
 
53
  def forward(self, x: torch.Tensor) -> torch.Tensor:
54
-
55
  x = self.input_conv(x)
56
  xs = [x]
57
-
58
  for encoding_layer in self.encoder_layers:
59
  x = encoding_layer(F.max_pool2d(x, 2))
60
- xs.append(x)
61
-
62
- for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1])):
63
  x = decoding_layer(x_left, x)
64
-
65
  return self.output_conv(x)
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
  class DoubleConv(nn.Module):
6
+
7
  def __init__(self, in_channels: int, out_channels: int) -> None:
 
8
  super().__init__()
9
  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
10
  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
 
13
  return F.relu(self.conv2(F.relu(self.conv1(x))))
14
 
15
  class Up(nn.Module):
16
+
17
  def __init__(self, in_channels: int, out_channels: int) -> None:
18
  super().__init__()
19
  self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
20
  self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)
21
+
22
  def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
23
  return self.conv(torch.cat((x_left, self.upconv(x_right)), dim=1))
24
 
25
  class CustomUnet(nn.Module):
26
 
27
+ def __init__(self, in_channels: int, depth: int, start_channels: int) -> None:
 
28
  super().__init__()
 
29
  self.input_conv = DoubleConv(in_channels, start_channels)
 
30
  self.encoder_layers = nn.ModuleList()
31
  for i in range(depth):
32
  self.encoder_layers.append(DoubleConv(start_channels, start_channels * 2))
33
  start_channels *= 2
 
34
  self.decoder_layers = nn.ModuleList()
35
  for i in range(depth):
36
  self.decoder_layers.append(Up(start_channels, start_channels // 2))
37
  start_channels //= 2
 
38
  self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
39
 
40
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
41
  x = self.input_conv(x)
42
  xs = [x]
 
43
  for encoding_layer in self.encoder_layers:
44
  x = encoding_layer(F.max_pool2d(x, 2))
45
+ xs.append(x)
46
+ for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1]), strict=True):
 
47
  x = decoding_layer(x_left, x)
 
48
  return self.output_conv(x)
utils.py DELETED
@@ -1,20 +0,0 @@
1
- import albumentations as A
2
- from segmentation_models_pytorch import Unet
3
-
4
- val_transform = A.Compose(
5
- transforms=[
6
- A.Resize(320, 320)
7
- ],
8
- is_check_shapes=False
9
- )
10
-
11
- def get_pretrained_unet() -> Unet:
12
- unet = Unet(
13
- encoder_name='timm-efficientnet-b0',
14
- encoder_weights='imagenet',
15
- in_channels=3,
16
- encoder_depth=5,
17
- classes=1,
18
- activation=None
19
- )
20
- return unet