saicoder commited on
Commit
d224b5d
·
unverified ·
0 Parent(s):

ML for mask segmentation

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ .checkpoint
Makefile ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ dev:
2
+ uvicorn app.server:app --reload
api/index.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from app.server import app
app/cloth_segmentation/model.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+ from urllib.request import urlretrieve
6
+ from collections import OrderedDict
7
+ from PIL import Image
8
+ from .network import U2NET
9
+ from torchvision import transforms
10
+ from enum import Enum
11
+ from dataclasses import dataclass
12
+ from typing import Optional
13
+
14
+ MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
15
+
16
+
17
+ class Mode(Enum):
18
+ BINARY = 'binary'
19
+ VARIABLE = 'variable'
20
+
21
+
22
+ @dataclass
23
+ class Result:
24
+ upper_body: Optional[Image.Image] = None
25
+ lower_body: Optional[Image.Image] = None
26
+ full_body: Optional[Image.Image] = None
27
+
28
+
29
+ def check_or_download_model(file_path):
30
+ if not os.path.exists(file_path):
31
+ print("No model found, downloading model")
32
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
33
+ urlretrieve(MODEL_URL, file_path)
34
+
35
+
36
+ def load_checkpoint(model, checkpoint_path):
37
+ model_state_dict = torch.load(
38
+ checkpoint_path, map_location=torch.device("cpu"))
39
+ new_state_dict = OrderedDict()
40
+ for k, v in model_state_dict.items():
41
+ name = k[7:] # remove `module.`
42
+ new_state_dict[name] = v
43
+
44
+ model.load_state_dict(new_state_dict)
45
+ print("----checkpoints loaded from path: {}----".format(checkpoint_path))
46
+ return model
47
+
48
+
49
+ def load_seg_model(checkpoint_path, device='cpu'):
50
+ net = U2NET(in_ch=3, out_ch=4)
51
+ check_or_download_model(checkpoint_path)
52
+ net = load_checkpoint(net, checkpoint_path)
53
+ net = net.to(device)
54
+ net = net.eval()
55
+
56
+ return net
57
+
58
+
59
+ def apply_transform(image):
60
+ transform_rgb = transforms.Compose(
61
+ [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
62
+ return transform_rgb(image)
63
+
64
+
65
+ def segment(image: Image, net: U2NET, mode=Mode.BINARY) -> Result:
66
+ original_size = image.size
67
+
68
+ image = image.resize((768, 768), Image.BICUBIC).convert('RGB')
69
+ image_tensor = apply_transform(image)
70
+ image_tensor = torch.unsqueeze(image_tensor, 0)
71
+
72
+ images = []
73
+
74
+ with torch.no_grad():
75
+ output_tensor: torch.Tensor = net(image_tensor.to("cpu"))
76
+ output_tensor = torch.nn.functional.log_softmax(
77
+ output_tensor[0], dim=1)
78
+
79
+ if mode == Mode.VARIABLE:
80
+ for mask in range(1, 4):
81
+ mask_probabilities = output_tensor.reshape((4, 768, 768))[mask]
82
+ mask_probabilities = torch.exp(mask_probabilities).numpy()
83
+
84
+ mask_probabilities = (mask_probabilities * 255).astype(np.uint8)
85
+
86
+ alpha_mask_img = Image.fromarray(mask_probabilities, mode='L')
87
+ alpha_mask_img = alpha_mask_img.resize(
88
+ original_size, Image.BICUBIC)
89
+ images.append(alpha_mask_img)
90
+ else:
91
+ output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
92
+ output_tensor = torch.squeeze(output_tensor, dim=0)
93
+ output_arr = output_tensor.cpu().numpy()
94
+
95
+ for mask in range(1, 4):
96
+ alpha_mask = (output_arr == mask).astype(np.uint8) * 255
97
+ alpha_mask_img = Image.fromarray(alpha_mask[0], mode='L')
98
+ alpha_mask_img = alpha_mask_img.resize(
99
+ original_size, Image.BICUBIC)
100
+ images.append(alpha_mask_img)
101
+
102
+ result = Result()
103
+ result.upper_body = images[0]
104
+ result.lower_body = images[1]
105
+ result.full_body_body = images[2]
106
+
107
+ return result
app/cloth_segmentation/network.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
8
+ super(REBNCONV, self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(
11
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
12
+ )
13
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
14
+ self.relu_s1 = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ # upsample tensor 'src' to have the same spatial size with tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+
27
+ src = F.upsample(src, size=tar.shape[2:], mode="bilinear")
28
+
29
+ return src
30
+
31
+
32
+ ### RSU-7 ###
33
+ class RSU7(nn.Module): # UNet07DRES(nn.Module):
34
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
35
+ super(RSU7, self).__init__()
36
+
37
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
38
+
39
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
40
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41
+
42
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
43
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+
56
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
57
+
58
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
59
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
60
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
61
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
64
+
65
+ def forward(self, x):
66
+
67
+ hx = x
68
+ hxin = self.rebnconvin(hx)
69
+
70
+ hx1 = self.rebnconv1(hxin)
71
+ hx = self.pool1(hx1)
72
+
73
+ hx2 = self.rebnconv2(hx)
74
+ hx = self.pool2(hx2)
75
+
76
+ hx3 = self.rebnconv3(hx)
77
+ hx = self.pool3(hx3)
78
+
79
+ hx4 = self.rebnconv4(hx)
80
+ hx = self.pool4(hx4)
81
+
82
+ hx5 = self.rebnconv5(hx)
83
+ hx = self.pool5(hx5)
84
+
85
+ hx6 = self.rebnconv6(hx)
86
+
87
+ hx7 = self.rebnconv7(hx6)
88
+
89
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
90
+ hx6dup = _upsample_like(hx6d, hx5)
91
+
92
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
93
+ hx5dup = _upsample_like(hx5d, hx4)
94
+
95
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
96
+ hx4dup = _upsample_like(hx4d, hx3)
97
+
98
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
99
+ hx3dup = _upsample_like(hx3d, hx2)
100
+
101
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
102
+ hx2dup = _upsample_like(hx2d, hx1)
103
+
104
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
105
+
106
+ """
107
+ del hx1, hx2, hx3, hx4, hx5, hx6, hx7
108
+ del hx6d, hx5d, hx3d, hx2d
109
+ del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
110
+ """
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module): # UNet06DRES(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+
146
+ hx = x
147
+
148
+ hxin = self.rebnconvin(hx)
149
+
150
+ hx1 = self.rebnconv1(hxin)
151
+ hx = self.pool1(hx1)
152
+
153
+ hx2 = self.rebnconv2(hx)
154
+ hx = self.pool2(hx2)
155
+
156
+ hx3 = self.rebnconv3(hx)
157
+ hx = self.pool3(hx3)
158
+
159
+ hx4 = self.rebnconv4(hx)
160
+ hx = self.pool4(hx4)
161
+
162
+ hx5 = self.rebnconv5(hx)
163
+
164
+ hx6 = self.rebnconv6(hx5)
165
+
166
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
167
+ hx5dup = _upsample_like(hx5d, hx4)
168
+
169
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
170
+ hx4dup = _upsample_like(hx4d, hx3)
171
+
172
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
173
+ hx3dup = _upsample_like(hx3d, hx2)
174
+
175
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
176
+ hx2dup = _upsample_like(hx2d, hx1)
177
+
178
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
179
+
180
+ """
181
+ del hx1, hx2, hx3, hx4, hx5, hx6
182
+ del hx5d, hx4d, hx3d, hx2d
183
+ del hx2dup, hx3dup, hx4dup, hx5dup
184
+ """
185
+
186
+ return hx1d + hxin
187
+
188
+
189
+ ### RSU-5 ###
190
+ class RSU5(nn.Module): # UNet05DRES(nn.Module):
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5, self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
213
+
214
+ def forward(self, x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
234
+ hx4dup = _upsample_like(hx4d, hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
237
+ hx3dup = _upsample_like(hx3d, hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
240
+ hx2dup = _upsample_like(hx2d, hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
243
+
244
+ """
245
+ del hx1, hx2, hx3, hx4, hx5
246
+ del hx4d, hx3d, hx2d
247
+ del hx2dup, hx3dup, hx4dup
248
+ """
249
+
250
+ return hx1d + hxin
251
+
252
+
253
+ ### RSU-4 ###
254
+ class RSU4(nn.Module): # UNet04DRES(nn.Module):
255
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
256
+ super(RSU4, self).__init__()
257
+
258
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
259
+
260
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
261
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
262
+
263
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
264
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
265
+
266
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
267
+
268
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
269
+
270
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
271
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
272
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
273
+
274
+ def forward(self, x):
275
+
276
+ hx = x
277
+
278
+ hxin = self.rebnconvin(hx)
279
+
280
+ hx1 = self.rebnconv1(hxin)
281
+ hx = self.pool1(hx1)
282
+
283
+ hx2 = self.rebnconv2(hx)
284
+ hx = self.pool2(hx2)
285
+
286
+ hx3 = self.rebnconv3(hx)
287
+
288
+ hx4 = self.rebnconv4(hx3)
289
+
290
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
291
+ hx3dup = _upsample_like(hx3d, hx2)
292
+
293
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
294
+ hx2dup = _upsample_like(hx2d, hx1)
295
+
296
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
297
+
298
+ """
299
+ del hx1, hx2, hx3, hx4
300
+ del hx3d, hx2d
301
+ del hx2dup, hx3dup
302
+ """
303
+
304
+ return hx1d + hxin
305
+
306
+
307
+ ### RSU-4F ###
308
+ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
309
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
310
+ super(RSU4F, self).__init__()
311
+
312
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
313
+
314
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
315
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
316
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
317
+
318
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
319
+
320
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
321
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
322
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
323
+
324
+ def forward(self, x):
325
+
326
+ hx = x
327
+
328
+ hxin = self.rebnconvin(hx)
329
+
330
+ hx1 = self.rebnconv1(hxin)
331
+ hx2 = self.rebnconv2(hx1)
332
+ hx3 = self.rebnconv3(hx2)
333
+
334
+ hx4 = self.rebnconv4(hx3)
335
+
336
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
337
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
338
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
339
+
340
+ """
341
+ del hx1, hx2, hx3, hx4
342
+ del hx3d, hx2d
343
+ """
344
+
345
+ return hx1d + hxin
346
+
347
+
348
+ ##### U^2-Net ####
349
+ class U2NET(nn.Module):
350
+ def __init__(self, in_ch=3, out_ch=1):
351
+ super(U2NET, self).__init__()
352
+
353
+ self.stage1 = RSU7(in_ch, 32, 64)
354
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage2 = RSU6(64, 32, 128)
357
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage3 = RSU5(128, 64, 256)
360
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage4 = RSU4(256, 128, 512)
363
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage5 = RSU4F(512, 256, 512)
366
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage6 = RSU4F(512, 256, 512)
369
+
370
+ # decoder
371
+ self.stage5d = RSU4F(1024, 256, 512)
372
+ self.stage4d = RSU4(1024, 128, 256)
373
+ self.stage3d = RSU5(512, 64, 128)
374
+ self.stage2d = RSU6(256, 32, 64)
375
+ self.stage1d = RSU7(128, 16, 64)
376
+
377
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
378
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
379
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
380
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
381
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
382
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
383
+
384
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
385
+
386
+ def forward(self, x):
387
+
388
+ hx = x
389
+
390
+ # stage 1
391
+ hx1 = self.stage1(hx)
392
+ hx = self.pool12(hx1)
393
+
394
+ # stage 2
395
+ hx2 = self.stage2(hx)
396
+ hx = self.pool23(hx2)
397
+
398
+ # stage 3
399
+ hx3 = self.stage3(hx)
400
+ hx = self.pool34(hx3)
401
+
402
+ # stage 4
403
+ hx4 = self.stage4(hx)
404
+ hx = self.pool45(hx4)
405
+
406
+ # stage 5
407
+ hx5 = self.stage5(hx)
408
+ hx = self.pool56(hx5)
409
+
410
+ # stage 6
411
+ hx6 = self.stage6(hx)
412
+ hx6up = _upsample_like(hx6, hx5)
413
+
414
+ # -------------------- decoder --------------------
415
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
416
+ hx5dup = _upsample_like(hx5d, hx4)
417
+
418
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
419
+ hx4dup = _upsample_like(hx4d, hx3)
420
+
421
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
422
+ hx3dup = _upsample_like(hx3d, hx2)
423
+
424
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
425
+ hx2dup = _upsample_like(hx2d, hx1)
426
+
427
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
428
+
429
+ # side output
430
+ d1 = self.side1(hx1d)
431
+
432
+ d2 = self.side2(hx2d)
433
+ d2 = _upsample_like(d2, d1)
434
+
435
+ d3 = self.side3(hx3d)
436
+ d3 = _upsample_like(d3, d1)
437
+
438
+ d4 = self.side4(hx4d)
439
+ d4 = _upsample_like(d4, d1)
440
+
441
+ d5 = self.side5(hx5d)
442
+ d5 = _upsample_like(d5, d1)
443
+
444
+ d6 = self.side6(hx6)
445
+ d6 = _upsample_like(d6, d1)
446
+
447
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
448
+
449
+ """
450
+ del hx1, hx2, hx3, hx4, hx5, hx6
451
+ del hx5d, hx4d, hx3d, hx2d, hx1d
452
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
453
+ """
454
+
455
+ return d0, d1, d2, d3, d4, d5, d6
456
+
457
+
458
+ ### U^2-Net small ###
459
+ class U2NETP(nn.Module):
460
+ def __init__(self, in_ch=3, out_ch=1):
461
+ super(U2NETP, self).__init__()
462
+
463
+ self.stage1 = RSU7(in_ch, 16, 64)
464
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
465
+
466
+ self.stage2 = RSU6(64, 16, 64)
467
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
468
+
469
+ self.stage3 = RSU5(64, 16, 64)
470
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
471
+
472
+ self.stage4 = RSU4(64, 16, 64)
473
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
474
+
475
+ self.stage5 = RSU4F(64, 16, 64)
476
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
477
+
478
+ self.stage6 = RSU4F(64, 16, 64)
479
+
480
+ # decoder
481
+ self.stage5d = RSU4F(128, 16, 64)
482
+ self.stage4d = RSU4(128, 16, 64)
483
+ self.stage3d = RSU5(128, 16, 64)
484
+ self.stage2d = RSU6(128, 16, 64)
485
+ self.stage1d = RSU7(128, 16, 64)
486
+
487
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
488
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
489
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
490
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
491
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
492
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
493
+
494
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
495
+
496
+ def forward(self, x):
497
+
498
+ hx = x
499
+
500
+ # stage 1
501
+ hx1 = self.stage1(hx)
502
+ hx = self.pool12(hx1)
503
+
504
+ # stage 2
505
+ hx2 = self.stage2(hx)
506
+ hx = self.pool23(hx2)
507
+
508
+ # stage 3
509
+ hx3 = self.stage3(hx)
510
+ hx = self.pool34(hx3)
511
+
512
+ # stage 4
513
+ hx4 = self.stage4(hx)
514
+ hx = self.pool45(hx4)
515
+
516
+ # stage 5
517
+ hx5 = self.stage5(hx)
518
+ hx = self.pool56(hx5)
519
+
520
+ # stage 6
521
+ hx6 = self.stage6(hx)
522
+ hx6up = _upsample_like(hx6, hx5)
523
+
524
+ # decoder
525
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
526
+ hx5dup = _upsample_like(hx5d, hx4)
527
+
528
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
529
+ hx4dup = _upsample_like(hx4d, hx3)
530
+
531
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
532
+ hx3dup = _upsample_like(hx3d, hx2)
533
+
534
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
535
+ hx2dup = _upsample_like(hx2d, hx1)
536
+
537
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
538
+
539
+ # side output
540
+ d1 = self.side1(hx1d)
541
+
542
+ d2 = self.side2(hx2d)
543
+ d2 = _upsample_like(d2, d1)
544
+
545
+ d3 = self.side3(hx3d)
546
+ d3 = _upsample_like(d3, d1)
547
+
548
+ d4 = self.side4(hx4d)
549
+ d4 = _upsample_like(d4, d1)
550
+
551
+ d5 = self.side5(hx5d)
552
+ d5 = _upsample_like(d5, d1)
553
+
554
+ d6 = self.side6(hx6)
555
+ d6 = _upsample_like(d6, d1)
556
+
557
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
558
+
559
+ return d0, d1, d2, d3, d4, d5, d6
app/server.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from fastapi import FastAPI, UploadFile
3
+ from contextlib import asynccontextmanager
4
+
5
+ from pydantic.dataclasses import dataclass
6
+ from app.cloth_segmentation.model import segment, Mode, load_seg_model, Result, U2NET
7
+ from PIL import Image
8
+
9
+ from base64 import b64encode
10
+ from io import BytesIO
11
+
12
+
13
+ # === DTOs ==
14
+
15
+
16
+ @dataclass
17
+ class MaskResponse:
18
+ upper_body: Optional[str]
19
+ lower_body: Optional[str]
20
+ full_body: Optional[str]
21
+
22
+ def __image_to_base64(image: Image.Image | None):
23
+ if image == None:
24
+ return None
25
+ buffered = BytesIO()
26
+ image.save(buffered, format="JPEG")
27
+ return "data:image/png;base64," + b64encode(buffered.getvalue()).decode()
28
+
29
+ @classmethod
30
+ def from_segmentation(self, data: Result):
31
+ return MaskResponse(upper_body=self.__image_to_base64(data.upper_body),
32
+ lower_body=self.__image_to_base64(data.lower_body),
33
+ full_body=self.__image_to_base64(data.full_body),)
34
+
35
+
36
+ # === Context ===
37
+ ml_models = {}
38
+
39
+
40
+ @asynccontextmanager
41
+ async def lifespan(app: FastAPI):
42
+ ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
43
+ yield
44
+ ml_models.clear()
45
+
46
+
47
+ app = FastAPI(lifespan=lifespan)
48
+ # === Routes
49
+
50
+
51
+ @app.get("/")
52
+ def index():
53
+ return {"ok": True, "message": "Invalid route, please check `/docs`"}
54
+
55
+
56
+ @app.post("/mask")
57
+ def mask(upload: UploadFile) -> MaskResponse:
58
+ net = ml_models["cloth_segmentation"]
59
+ image = Image.open(upload.file)
60
+ result = segment(image, net, mode=Mode.BINARY)
61
+
62
+ return MaskResponse.from_segmentation(result)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ uvicorn==0.23.2
2
+ fastapi==0.104.1
3
+ pydantic==2.4.2
4
+
5
+ Pillow==9.5.0
6
+ numpy==1.24.3
7
+ torch==2.1.0
8
+ torchvision==0.16.0