Spaces:
Sleeping
Sleeping
updated
Browse files- app.py +24 -16
- custom_unet.py +6 -76
- early_stopper.py +0 -23
- model.py +0 -130
- models/{custom_unet.pth → custom_unet.pt} +2 -2
- models/{unet.pth → pretrained_unet.pt} +2 -2
- requirements.txt +1 -2
- unet.py +0 -67
- utils.py +20 -0
app.py
CHANGED
|
@@ -1,36 +1,44 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
import os
|
| 4 |
-
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
from model import SegmentationModel
|
| 9 |
from custom_unet import CustomUnet
|
| 10 |
-
from
|
| 11 |
|
| 12 |
-
from typing import Dict, Union, Tuple, List
|
| 13 |
|
| 14 |
class GradioApp:
|
| 15 |
|
| 16 |
def __init__(self) -> None:
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
}
|
| 22 |
|
| 23 |
def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
self.models[model_name].
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
return img_file, [(
|
| 33 |
-
|
| 34 |
def launch(self):
|
| 35 |
|
| 36 |
examples_list = [['examples/' + example] for example in os.listdir('examples')]
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
import os
|
|
|
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
from torchvision.transforms.functional import resize
|
| 8 |
+
from typing import Tuple, List
|
| 9 |
|
|
|
|
| 10 |
from custom_unet import CustomUnet
|
| 11 |
+
from utils import val_transform, get_pretrained_unet
|
| 12 |
|
|
|
|
| 13 |
|
| 14 |
class GradioApp:
|
| 15 |
|
| 16 |
def __init__(self) -> None:
    """Load both segmentation models and move them to the available device.

    The chosen device is stored on the instance because ``predict`` moves
    input tensors to ``self.device`` before running inference.
    """
    # Bug fix: this was a local variable (`device = ...`), but predict()
    # reads `self.device`, which raised AttributeError on every call.
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hand-written U-Net; weights stored as a plain state dict.
    custom = CustomUnet().to(self.device).eval()
    custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=self.device))

    # segmentation-models-pytorch U-Net with an efficientnet-b0 encoder.
    pretrained = get_pretrained_unet().to(self.device).eval()
    pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=self.device))

    # Keys are the names shown in the Gradio model selector.
    self.models = {
        'Custom': custom,
        'Pretrained': pretrained
    }
|
| 30 |
|
| 31 |
def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
    """Segment the person in ``img_file`` with the selected model.

    Args:
        img_file: Path to the input image on disk.
        model_name: Key into ``self.models`` ('Custom' or 'Pretrained').

    Returns:
        The original image path together with a single ``(mask, 'person')``
        pair, where the binary mask has been resized back to the original
        image resolution (the format gradio's AnnotatedImage expects).
    """
    # Bug fix: was `image = image=np.asarray(...)` — an accidental chained
    # assignment left over from a keyword-argument edit.
    image = np.asarray(Image.open(img_file))
    # Original spatial size, so the mask can be mapped back onto the input.
    # NOTE(review): assumes an (H, W, C) color image — confirm grayscale
    # inputs cannot reach this point.
    h, w = image.shape[:-1]
    image = torch.from_numpy(val_transform(image=image)['image']).float().permute(2, 0, 1) / 255.

    # Derive the device from the model itself rather than trusting
    # `self.device` to have been set — makes this method self-contained.
    device = next(self.models[model_name].parameters()).device

    with torch.inference_mode():
        # sigmoid + round turns raw logits into a hard {0, 1} mask.
        prediction = self.models[model_name](image.to(device).unsqueeze(0))[0].sigmoid().round().cpu()
    # NEAREST keeps the mask binary when scaling back to (h, w).
    mask = resize(img=prediction, size=(h, w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()

    return img_file, [(mask, 'person')]
|
| 41 |
+
|
| 42 |
def launch(self):
|
| 43 |
|
| 44 |
examples_list = [['examples/' + example] for example in os.listdir('examples')]
|
custom_unet.py
CHANGED
|
@@ -7,46 +7,30 @@ Additional things: https://towardsdatascience.com/understanding-u-net-61276b10f3
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from model import SegmentationModel
|
| 13 |
-
from early_stopper import EarlyStopper
|
| 14 |
-
|
| 15 |
-
from typing import Tuple, Union, Optional
|
| 16 |
-
|
| 17 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 18 |
-
|
| 19 |
-
class DiceLoss(nn.Module):
|
| 20 |
-
|
| 21 |
-
def forward(self, logits: torch.Tensor, mask_true: torch.Tensor):
|
| 22 |
-
logits = torch.sigmoid(logits) > .5
|
| 23 |
-
intersection = (logits * mask_true).sum()
|
| 24 |
-
union = logits.sum() + mask_true.sum()
|
| 25 |
-
return 2 * intersection / union
|
| 26 |
|
| 27 |
class DoubleConv(nn.Module):
|
| 28 |
|
| 29 |
def __init__(self, in_channels: int, out_channels: int) -> None:
|
| 30 |
|
| 31 |
super().__init__()
|
| 32 |
-
self.relu = nn.ReLU()
|
| 33 |
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
|
| 34 |
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
|
| 35 |
|
| 36 |
def forward(self, x: torch.Tensor):
|
| 37 |
-
return
|
| 38 |
|
| 39 |
class Up(nn.Module):
|
| 40 |
|
| 41 |
-
def __init__(self, in_channels, out_channels) -> None:
|
| 42 |
super().__init__()
|
| 43 |
self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
|
| 44 |
self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)
|
| 45 |
|
| 46 |
-
def forward(self, x_left, x_right):
|
| 47 |
return self.conv(torch.cat((x_left, self.upconv(x_right)), dim=1))
|
| 48 |
|
| 49 |
-
class
|
| 50 |
|
| 51 |
def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
|
| 52 |
|
|
@@ -65,8 +49,6 @@ class UnetModel(nn.Module):
|
|
| 65 |
start_channels //= 2
|
| 66 |
|
| 67 |
self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
|
| 68 |
-
|
| 69 |
-
self.pool = nn.MaxPool2d(2, 2)
|
| 70 |
|
| 71 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 72 |
|
|
@@ -74,62 +56,10 @@ class UnetModel(nn.Module):
|
|
| 74 |
xs = [x]
|
| 75 |
|
| 76 |
for encoding_layer in self.encoder_layers:
|
| 77 |
-
x = encoding_layer(
|
| 78 |
xs.append(x)
|
| 79 |
|
| 80 |
for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1])):
|
| 81 |
x = decoding_layer(x_left, x)
|
| 82 |
|
| 83 |
return self.output_conv(x)
|
| 84 |
-
|
| 85 |
-
class CustomUnet(SegmentationModel):
|
| 86 |
-
|
| 87 |
-
def __init__(self,
|
| 88 |
-
name: str = 'default_name',
|
| 89 |
-
from_file: bool = True,
|
| 90 |
-
image_size: Tuple[int, int] = (320, 320),
|
| 91 |
-
in_channels: int = 3,
|
| 92 |
-
start_channels: int = 16,
|
| 93 |
-
encoder_depth: int = 5,
|
| 94 |
-
device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
|
| 95 |
-
|
| 96 |
-
super().__init__()
|
| 97 |
-
|
| 98 |
-
assert image_size[0] % (2**encoder_depth) == 0
|
| 99 |
-
assert image_size[1] % (2**encoder_depth) == 0
|
| 100 |
-
|
| 101 |
-
self.name = name
|
| 102 |
-
self.image_size = image_size
|
| 103 |
-
self.in_channels = in_channels
|
| 104 |
-
self.device = device
|
| 105 |
-
|
| 106 |
-
self.save_path = f'models/{name}.pth'
|
| 107 |
-
|
| 108 |
-
if from_file:
|
| 109 |
-
self.unet = torch.load(self.save_path, map_location=device)
|
| 110 |
-
else:
|
| 111 |
-
self.unet = UnetModel(in_channels=in_channels, depth=encoder_depth, start_channels=start_channels).to(device)
|
| 112 |
-
|
| 113 |
-
self.bce_loss = nn.BCEWithLogitsLoss()
|
| 114 |
-
self.dice_loss = DiceLoss()
|
| 115 |
-
self.loss_fn = lambda logits, masks: self.bce_loss(logits, masks) + self.dice_loss(logits, masks)
|
| 116 |
-
|
| 117 |
-
def configure_optimizers(self, **kwargs):
|
| 118 |
-
self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
|
| 119 |
-
self.early_stopper = EarlyStopper(patience=kwargs['patience'])
|
| 120 |
-
|
| 121 |
-
def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
-
logits = self.unet(images)
|
| 123 |
-
if masks is None:
|
| 124 |
-
return logits
|
| 125 |
-
return logits, self.loss_fn(logits, masks)
|
| 126 |
-
|
| 127 |
-
def save(self) -> None:
|
| 128 |
-
# Save the whole model, not only the state dict, so that it will work for different unets
|
| 129 |
-
torch.save(self.unet, self.save_path)
|
| 130 |
-
|
| 131 |
-
def print_summary(self, batch_size: int = 16) -> None:
|
| 132 |
-
|
| 133 |
-
print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
|
| 134 |
-
col_names=['input_size', 'output_size', 'num_params'],
|
| 135 |
-
row_settings=['var_names']))
|
|
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class DoubleConv(nn.Module):
    """Two 3x3 'same'-padded convolutions, each followed by a ReLU.

    The standard U-Net building block: spatial size is preserved while the
    channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Attribute names conv1/conv2 are kept so saved state dicts load.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

    def forward(self, x: torch.Tensor):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
|
| 22 |
|
| 23 |
class Up(nn.Module):
    """Decoder stage of the U-Net.

    Upsamples the deeper feature map with a stride-2 transposed convolution,
    concatenates the encoder skip connection along the channel axis, then
    refines the result with a DoubleConv.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Attribute names upconv/conv are kept so saved state dicts load.
        self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        upsampled = self.upconv(x_right)
        merged = torch.cat((x_left, upsampled), dim=1)
        return self.conv(merged)
|
| 32 |
|
| 33 |
+
class CustomUnet(nn.Module):
|
| 34 |
|
| 35 |
def __init__(self, in_channels: int = 3, depth: int = 3, start_channels: int = 16) -> None:
|
| 36 |
|
|
|
|
| 49 |
start_channels //= 2
|
| 50 |
|
| 51 |
self.output_conv = nn.Conv2d(start_channels, 1, kernel_size=1)
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 54 |
|
|
|
|
| 56 |
xs = [x]
|
| 57 |
|
| 58 |
for encoding_layer in self.encoder_layers:
|
| 59 |
+
x = encoding_layer(F.max_pool2d(x, 2))
|
| 60 |
xs.append(x)
|
| 61 |
|
| 62 |
for decoding_layer, x_left in zip(self.decoder_layers, reversed(xs[:-1])):
|
| 63 |
x = decoding_layer(x_left, x)
|
| 64 |
|
| 65 |
return self.output_conv(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
early_stopper.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""This module contains a class that implements early stopping regularization technique"""
|
| 2 |
-
|
| 3 |
-
class EarlyStopper:
|
| 4 |
-
|
| 5 |
-
def __init__(self, patience: int = 2):
|
| 6 |
-
|
| 7 |
-
self.patience = patience
|
| 8 |
-
self.best_loss = float('inf')
|
| 9 |
-
self.counter = 0
|
| 10 |
-
self.save_model = False
|
| 11 |
-
|
| 12 |
-
def check(self, validation_loss: float) -> bool:
|
| 13 |
-
|
| 14 |
-
self.save_model = False
|
| 15 |
-
if validation_loss > self.best_loss:
|
| 16 |
-
self.counter += 1
|
| 17 |
-
if self.counter == self.patience:
|
| 18 |
-
return True
|
| 19 |
-
else:
|
| 20 |
-
self.best_loss = validation_loss
|
| 21 |
-
self.counter = 0
|
| 22 |
-
self.save_model = True
|
| 23 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
"""This module contains the base class for segmentation models"""
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
from torch.utils.data import DataLoader
|
| 6 |
-
from torchvision.utils import draw_segmentation_masks
|
| 7 |
-
from torchvision.transforms.functional import resize
|
| 8 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import cv2 as cv
|
| 12 |
-
import albumentations as A
|
| 13 |
-
|
| 14 |
-
from typing import Optional, Union, Tuple, Literal
|
| 15 |
-
|
| 16 |
-
from early_stopper import EarlyStopper
|
| 17 |
-
|
| 18 |
-
class SegmentationModel(nn.Module):
|
| 19 |
-
|
| 20 |
-
name: str = "base name"
|
| 21 |
-
device: Literal['cpu', 'cuda'] = None
|
| 22 |
-
|
| 23 |
-
optimizer: torch.optim.Optimizer = None
|
| 24 |
-
early_stopper: EarlyStopper = None
|
| 25 |
-
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
|
| 26 |
-
save_path: str = None
|
| 27 |
-
image_size: Tuple[int, int] = None
|
| 28 |
-
|
| 29 |
-
def configure_optimizers(self, **kwargs) -> None:
|
| 30 |
-
raise NotImplementedError()
|
| 31 |
-
|
| 32 |
-
def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 33 |
-
|
| 34 |
-
raise NotImplementedError()
|
| 35 |
-
|
| 36 |
-
def _train_step(self, data_loader: DataLoader) -> float:
|
| 37 |
-
|
| 38 |
-
self.train()
|
| 39 |
-
total_loss = 0.
|
| 40 |
-
for images, masks in data_loader:
|
| 41 |
-
images, masks = images.to(self.device), masks.to(self.device)
|
| 42 |
-
|
| 43 |
-
self.optimizer.zero_grad()
|
| 44 |
-
logits, loss = self(images, masks)
|
| 45 |
-
loss.backward()
|
| 46 |
-
self.optimizer.step()
|
| 47 |
-
|
| 48 |
-
total_loss += loss.item()
|
| 49 |
-
|
| 50 |
-
return total_loss / len(data_loader)
|
| 51 |
-
|
| 52 |
-
def _test_step(self, data_loader: DataLoader) -> float:
|
| 53 |
-
|
| 54 |
-
self.eval()
|
| 55 |
-
total_loss = 0.
|
| 56 |
-
with torch.inference_mode():
|
| 57 |
-
for images, masks in data_loader:
|
| 58 |
-
images, masks = images.to(self.device), masks.to(self.device)
|
| 59 |
-
logits, loss = self(images, masks)
|
| 60 |
-
total_loss += loss.item()
|
| 61 |
-
return total_loss / len(data_loader)
|
| 62 |
-
|
| 63 |
-
def train_model(self, train_loader: DataLoader, test_loader: DataLoader, epochs: int, log_dir: str) -> None:
|
| 64 |
-
|
| 65 |
-
writer = SummaryWriter(log_dir=f'{log_dir}/{self.name}')
|
| 66 |
-
|
| 67 |
-
for i in range(epochs):
|
| 68 |
-
train_loss = self._train_step(train_loader)
|
| 69 |
-
test_loss = self._test_step(test_loader)
|
| 70 |
-
|
| 71 |
-
if self.early_stopper is not None:
|
| 72 |
-
if self.early_stopper.check(test_loss):
|
| 73 |
-
print(f'Model stopped early due to risk of overfitting')
|
| 74 |
-
break
|
| 75 |
-
|
| 76 |
-
if self.early_stopper.save_model:
|
| 77 |
-
self.save()
|
| 78 |
-
print('saved model')
|
| 79 |
-
|
| 80 |
-
if self.lr_scheduler is not None:
|
| 81 |
-
self.lr_scheduler.step()
|
| 82 |
-
|
| 83 |
-
print(f'{i}: Train loss: {train_loss :.2} | Test loss: {test_loss :.2}')
|
| 84 |
-
|
| 85 |
-
writer.add_scalars(main_tag='Loss over time',
|
| 86 |
-
tag_scalar_dict={'train loss': train_loss, 'test loss': test_loss},
|
| 87 |
-
global_step=i)
|
| 88 |
-
|
| 89 |
-
else:
|
| 90 |
-
if self.early_stopper is not None:
|
| 91 |
-
print('Model did not converge. Possibility of underfitting')
|
| 92 |
-
self.save()
|
| 93 |
-
writer.close()
|
| 94 |
-
|
| 95 |
-
def save(self) -> None:
|
| 96 |
-
raise NotImplementedError()
|
| 97 |
-
|
| 98 |
-
def predict(self,
|
| 99 |
-
test_image_path: str,
|
| 100 |
-
option: Literal['mask', 'image_with_mask', 'mask_and_image_with_mask'] = 'image_with_mask'
|
| 101 |
-
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
| 102 |
-
|
| 103 |
-
self.eval()
|
| 104 |
-
input_resizer = A.Resize(*self.image_size)
|
| 105 |
-
|
| 106 |
-
original_image = cv.cvtColor(cv.imread(test_image_path), cv.COLOR_BGR2RGB)
|
| 107 |
-
original_image_tensor = torch.from_numpy(original_image).permute(2,0,1).type(torch.uint8)
|
| 108 |
-
resized_image_tensor = (torch.from_numpy(input_resizer(image=original_image)['image']).float() / 255.).permute(2,0,1)
|
| 109 |
-
|
| 110 |
-
with torch.inference_mode():
|
| 111 |
-
logits = self(resized_image_tensor.unsqueeze(0).to(self.device)).squeeze(0).cpu().detach()
|
| 112 |
-
probs = torch.sigmoid(logits)
|
| 113 |
-
resized_mask_tensor = probs > .5
|
| 114 |
-
|
| 115 |
-
original_mask_tensor = resize(resized_mask_tensor, size=original_image.shape[:-1], antialias=True)
|
| 116 |
-
|
| 117 |
-
image_with_mask = draw_segmentation_masks(image=original_image_tensor,
|
| 118 |
-
masks=original_mask_tensor,
|
| 119 |
-
alpha=.5,
|
| 120 |
-
colors='white')
|
| 121 |
-
|
| 122 |
-
if option == 'mask':
|
| 123 |
-
return original_mask_tensor.numpy()
|
| 124 |
-
if option == 'image_with_mask':
|
| 125 |
-
return image_with_mask.permute(1,2,0).numpy()
|
| 126 |
-
if option == 'mask_and_image_with_mask':
|
| 127 |
-
return original_mask_tensor.numpy(), image_with_mask.permute(1,2,0).numpy()
|
| 128 |
-
|
| 129 |
-
def print_summary(self) -> None:
|
| 130 |
-
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/{custom_unet.pth → custom_unet.pt}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a227b749031ac42b97c9833bd18e8c37b5104bb94546f2063310a73a9b912fe5
|
| 3 |
+
size 1941386
|
models/{unet.pth → pretrained_unet.pt}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b3fe6a191c10e927901529593049ba76ccbab708ec62a8d25fa2f8b46cb4ac2
|
| 3 |
+
size 25339050
|
requirements.txt
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
torch
|
| 2 |
-
torchinfo
|
| 3 |
segmentation-models-pytorch
|
| 4 |
albumentations
|
| 5 |
opencv-python
|
| 6 |
gradio
|
| 7 |
numpy
|
| 8 |
matplotlib
|
| 9 |
-
tensorboard
|
|
|
|
| 1 |
torch
|
|
|
|
| 2 |
segmentation-models-pytorch
|
| 3 |
albumentations
|
| 4 |
opencv-python
|
| 5 |
gradio
|
| 6 |
numpy
|
| 7 |
matplotlib
|
| 8 |
+
tensorboard
|
unet.py
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
"""This module defines a Unet architecture"""
|
| 2 |
-
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch
|
| 5 |
-
from torchinfo import summary
|
| 6 |
-
import segmentation_models_pytorch as smp
|
| 7 |
-
from early_stopper import EarlyStopper
|
| 8 |
-
|
| 9 |
-
from model import SegmentationModel
|
| 10 |
-
|
| 11 |
-
from typing import Optional, Union, Tuple
|
| 12 |
-
|
| 13 |
-
class Unet(SegmentationModel):
|
| 14 |
-
|
| 15 |
-
def __init__(self,
|
| 16 |
-
name: str = 'default_name',
|
| 17 |
-
from_file: bool = True,
|
| 18 |
-
image_size: Tuple[int, int] = (320, 320),
|
| 19 |
-
encoder_name: str = 'timm-efficientnet-b0',
|
| 20 |
-
pretrained: bool = True,
|
| 21 |
-
in_channels: int = 3,
|
| 22 |
-
encoder_depth: int = 5,
|
| 23 |
-
device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> None:
|
| 24 |
-
|
| 25 |
-
super().__init__()
|
| 26 |
-
|
| 27 |
-
self.name = name
|
| 28 |
-
self.image_size = image_size
|
| 29 |
-
self.in_channels = in_channels
|
| 30 |
-
self.device = device
|
| 31 |
-
|
| 32 |
-
self.save_path = f'models/{name}.pth'
|
| 33 |
-
|
| 34 |
-
if from_file:
|
| 35 |
-
self.unet = torch.load(self.save_path, map_location=device)
|
| 36 |
-
else:
|
| 37 |
-
self.unet = smp.Unet(
|
| 38 |
-
encoder_name=encoder_name,
|
| 39 |
-
encoder_weights='imagenet' if pretrained else None,
|
| 40 |
-
in_channels=in_channels,
|
| 41 |
-
encoder_depth=encoder_depth,
|
| 42 |
-
classes=1,
|
| 43 |
-
activation=None
|
| 44 |
-
).to(device)
|
| 45 |
-
|
| 46 |
-
bce_loss_fn = nn.BCEWithLogitsLoss()
|
| 47 |
-
dice_loss_fn = smp.losses.DiceLoss(mode='binary')
|
| 48 |
-
self.loss_fn = lambda logits, masks: bce_loss_fn(logits, masks) + dice_loss_fn(logits, masks)
|
| 49 |
-
|
| 50 |
-
def configure_optimizers(self, **kwargs):
|
| 51 |
-
self.optimizer = torch.optim.Adam(params=self.unet.parameters(), lr=kwargs['lr'])
|
| 52 |
-
self.early_stopper = EarlyStopper(patience=kwargs['patience'])
|
| 53 |
-
|
| 54 |
-
def forward(self, images: torch.Tensor, masks: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 55 |
-
logits = self.unet(images)
|
| 56 |
-
if masks is None:
|
| 57 |
-
return logits
|
| 58 |
-
return logits, self.loss_fn(logits, masks)
|
| 59 |
-
|
| 60 |
-
def save(self) -> None:
|
| 61 |
-
# Save the whole model, not only the state dict, so that it will work for different unets
|
| 62 |
-
torch.save(self.unet, self.save_path)
|
| 63 |
-
|
| 64 |
-
def print_summary(self, batch_size: int = 16) -> None:
|
| 65 |
-
print(summary(self.unet, input_size=(batch_size, self.in_channels, *self.image_size),
|
| 66 |
-
col_names=['input_size', 'output_size', 'num_params'],
|
| 67 |
-
row_settings=['var_names']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import albumentations as A
|
| 2 |
+
from segmentation_models_pytorch import Unet
|
| 3 |
+
|
| 4 |
+
# Inference-time preprocessing shared by the app: resize every input to the
# fixed 320x320 resolution the models were trained on.
# is_check_shapes=False because only an image (no mask) is supplied here.
val_transform = A.Compose(
    transforms=[A.Resize(320, 320)],
    is_check_shapes=False,
)
|
| 10 |
+
|
| 11 |
+
def get_pretrained_unet() -> Unet:
    """Build the segmentation-models-pytorch U-Net used by the app.

    The configuration must match the checkpoint in
    ``models/pretrained_unet.pt``: timm-efficientnet-b0 encoder of depth 5,
    3 input channels, a single output class, and raw logits (no activation).
    """
    return Unet(
        encoder_name='timm-efficientnet-b0',
        encoder_weights='imagenet',
        in_channels=3,
        encoder_depth=5,
        classes=1,
        activation=None,
    )
|