Ayesha-Majeed commited on
Commit
a5500e5
·
verified ·
1 Parent(s): 910d021

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +387 -79
binary_segmentation.py CHANGED
@@ -1,11 +1,6 @@
1
  """
2
  Binary Image Segmentation Tool
3
- A lightweight, professional implementation for foreground object segmentation.
4
-
5
- Supports multiple models:
6
- - U2NETP (fastest, 1.1M params)
7
- - BiRefNet (best accuracy, larger model)
8
- - RMBG (good balance)
9
  """
10
 
11
  import os
@@ -14,10 +9,11 @@ from pathlib import Path
14
  from typing import Literal, Tuple, Optional
15
  import numpy as np
16
  import torch
 
 
17
  from PIL import Image
18
  from torchvision import transforms
19
  import cv2
20
- from u2net import U2NETP
21
 
22
  # Configure logging
23
  logging.basicConfig(
@@ -31,97 +27,409 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  logger.info(f"Using device: {DEVICE}")
32
 
33
 
34
- class U2NETP(torch.nn.Module):
35
- """U2-Net Portrait (U2NETP) - Lightweight segmentation model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def __init__(self, in_ch=3, out_ch=1):
38
  super(U2NETP, self).__init__()
39
-
40
- # Encoder
41
- self.stage1 = self._make_stage(in_ch, 16, 64)
42
- self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
43
-
44
- self.stage2 = self._make_stage(64, 16, 64)
45
- self.pool23 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
46
-
47
- self.stage3 = self._make_stage(64, 16, 64)
48
- self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
49
-
50
- self.stage4 = self._make_stage(64, 16, 64)
51
-
52
- # Bridge
53
- self.stage5 = self._make_stage(64, 16, 64)
54
-
55
- # Decoder
56
- self.stage4d = self._make_stage(128, 16, 64)
57
- self.stage3d = self._make_stage(128, 16, 64)
58
- self.stage2d = self._make_stage(128, 16, 64)
59
- self.stage1d = self._make_stage(128, 16, 64)
60
-
61
- # Side outputs
62
- self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
63
- self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
64
- self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
65
- self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
66
- self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
67
-
68
- # Output fusion
69
- self.outconv = torch.nn.Conv2d(5 * out_ch, out_ch, 1)
70
-
71
- def _make_stage(self, in_ch, mid_ch, out_ch):
72
- return torch.nn.Sequential(
73
- torch.nn.Conv2d(in_ch, mid_ch, 3, padding=1),
74
- torch.nn.ReLU(inplace=True),
75
- torch.nn.Conv2d(mid_ch, mid_ch, 3, padding=1),
76
- torch.nn.ReLU(inplace=True),
77
- torch.nn.Conv2d(mid_ch, out_ch, 3, padding=1),
78
- torch.nn.ReLU(inplace=True)
79
- )
80
-
81
  def forward(self, x):
82
  hx = x
83
-
84
- # Encoder
85
  hx1 = self.stage1(hx)
86
  hx = self.pool12(hx1)
87
-
 
88
  hx2 = self.stage2(hx)
89
  hx = self.pool23(hx2)
90
-
 
91
  hx3 = self.stage3(hx)
92
  hx = self.pool34(hx3)
93
-
 
94
  hx4 = self.stage4(hx)
95
- hx5 = self.stage5(hx4)
96
-
97
- # Decoder
98
- hx4d = self.stage4d(torch.cat((hx5, hx4), 1))
99
- hx4dup = torch.nn.functional.interpolate(hx4d, scale_factor=2, mode='bilinear', align_corners=True)
100
-
 
 
 
 
 
 
 
 
 
 
 
101
  hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
102
- hx3dup = torch.nn.functional.interpolate(hx3d, scale_factor=2, mode='bilinear', align_corners=True)
103
-
104
  hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
105
- hx2dup = torch.nn.functional.interpolate(hx2d, scale_factor=2, mode='bilinear', align_corners=True)
106
-
107
  hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
108
-
109
- # Side outputs
110
  d1 = self.side1(hx1d)
111
- d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
112
- d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
113
- d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
114
- d5 = torch.nn.functional.interpolate(self.side5(hx5), size=d1.shape[2:], mode='bilinear', align_corners=True)
115
-
116
- # Fusion
117
- d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5), 1))
118
-
119
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  class BinarySegmenter:
123
  """
124
- Professional binary segmentation tool with multiple model backends.
125
 
126
  Args:
127
  model_type: Choice of segmentation model
@@ -159,7 +467,7 @@ class BinarySegmenter:
159
  logger.info(f"{self.model_type} loaded successfully")
160
 
161
  def _load_u2netp(self):
162
- """Load U2NETP model (1.1M parameters, fastest)"""
163
  self.model = U2NETP(3, 1)
164
 
165
  # Try to load pretrained weights
 
1
  """
2
  Binary Image Segmentation Tool
3
+ Uses the ORIGINAL U2NETP architecture to match pretrained weights perfectly.
 
 
 
 
 
4
  """
5
 
6
  import os
 
9
  from typing import Literal, Tuple, Optional
10
  import numpy as np
11
  import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
  from PIL import Image
15
  from torchvision import transforms
16
  import cv2
 
17
 
18
  # Configure logging
19
  logging.basicConfig(
 
27
  logger.info(f"Using device: {DEVICE}")
28
 
29
 
30
+ # ==================== ORIGINAL U2NETP ARCHITECTURE ====================
31
+ # This matches your pretrained weights EXACTLY!
32
+
33
class REBNCONV(nn.Module):
    """Conv3x3 -> BatchNorm2d -> ReLU building block with optional dilation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate; padding is set equal to the dilation so a
            3x3 kernel preserves the spatial size (H, W).
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        # NOTE: attribute names (conv_s1/bn_s1/relu_s1) must stay as-is so
        # pretrained U2NETP state dicts load without key remapping.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=dirate, dilation=dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU; output has same H x W as input."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
44
+
45
+
46
+ def _upsample_like(src, tar):
47
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
48
+ return src
49
+
50
+
51
class RSU7(nn.Module):
    """Residual U-block with 7 levels (RSU-7) from the U^2-Net paper.

    A small encoder/decoder with internal skip connections; the block
    output is the decoder result plus the input projection (residual).
    Attribute names match the original implementation so pretrained
    weights load unchanged.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # Input projection to out_ch; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x downsample at each level.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv instead of further downsampling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each level fuses the upsampled deeper feature with the
        # matching encoder feature (channel concat -> mid_ch*2 inputs).
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path with skip connections.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
124
+
125
+
126
class RSU6(nn.Module):
    """Residual U-block with 6 levels (RSU-6) from the U^2-Net paper.

    Same structure as :class:`RSU7` but one encoder/decoder level
    shallower. Attribute names match the original implementation so
    pretrained weights load unchanged.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection to out_ch; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x downsample at each level.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv instead of further downsampling.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: fuses upsampled deeper features with encoder skips.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder path with skip connections.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
189
+
190
+
191
class RSU5(nn.Module):
    """Residual U-block with 5 levels (RSU-5) from the U^2-Net paper.

    Same structure as :class:`RSU6` but one level shallower. Attribute
    names match the original implementation so pretrained weights load
    unchanged.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection to out_ch; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x downsample at each level.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv instead of further downsampling.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: fuses upsampled deeper features with encoder skips.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder path with skip connections.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
244
+
245
+
246
class RSU4(nn.Module):
    """Residual U-block with 4 levels (RSU-4) from the U^2-Net paper.

    Same structure as :class:`RSU5` but one level shallower. Attribute
    names match the original implementation so pretrained weights load
    unchanged.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection to out_ch; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x downsample at each level.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv instead of further downsampling.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: fuses upsampled deeper features with encoder skips.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder path with skip connections.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
289
+
290
+
291
class RSU4F(nn.Module):
    """Dilated residual U-block (RSU-4F) from the U^2-Net paper.

    Unlike the other RSU blocks it never downsamples: all levels run at
    full resolution, using increasing dilation rates (1, 2, 4, 8) to grow
    the receptive field. Used at the deepest stages where the feature map
    is already small. Attribute names match the original implementation
    so pretrained weights load unchanged.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection to out_ch; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # "Encoder": no pooling, receptive field grows via dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # "Decoder": mirrored dilation rates, concatenated skips.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # No upsampling needed: every feature map is full resolution.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
322
+
323
+
324
class U2NETP(nn.Module):
    """Original lightweight U^2-Net (U2NETP, ~1.1M params).

    Six-stage encoder/decoder of RSU blocks with deep supervision: six
    side outputs plus a fused output. Attribute names match the official
    implementation so pretrained ``u2netp.pth`` weights load unchanged.

    Returns a 7-tuple of sigmoid probability maps
    ``(d0, d1, d2, d3, d4, d5, d6)`` where ``d0`` is the fused
    prediction and ``d1``-``d6`` are the per-stage side outputs, all
    upsampled to the input resolution of ``d1``.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # Encoder: progressively shallower RSU blocks as resolution drops.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Deepest stages use the dilated, non-downsampling RSU4F.
        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # Decoder: mirrors the encoder; inputs are concatenated skips (128 ch).
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads, one per decoder stage (deep supervision).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # 1x1 conv fusing the six side outputs into the final map.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # ---- Encoder ----
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # ---- Decoder (skip connections via channel concat) ----
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # ---- Side outputs, all resized to the finest map d1 ----
        d1 = self.side1(hx1d)

        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        # Fuse all side outputs into the final prediction.
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2),
                torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5),
                torch.sigmoid(d6))
426
+
427
+
428
+ # ==================== SEGMENTER CLASS ====================
429
 
430
  class BinarySegmenter:
431
  """
432
+ Professional binary segmentation tool using U2NETP.
433
 
434
  Args:
435
  model_type: Choice of segmentation model
 
467
  logger.info(f"{self.model_type} loaded successfully")
468
 
469
  def _load_u2netp(self):
470
+ """Load U2NETP model with original architecture"""
471
  self.model = U2NETP(3, 1)
472
 
473
  # Try to load pretrained weights