phucd committed on
Commit
ca8fa7a
·
1 Parent(s): a6314dd

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import matplotlib.pyplot as plt
3
+ import cv2
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import gradio as gr
9
+ from seg import U2NETP
10
+
11
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # Image processing utilities
14
+ def load_image(path: str):
15
+ """ Loads an image from the specified path and converts it to RGB format. """
16
+ img = cv2.imread(path)
17
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
18
+ return img / 255.0
19
+
20
+ def save_image(image: np.ndarray, path: str):
21
+ """ Saves an image to the specified path. """
22
+ img = (image * 255).astype(np.uint8)
23
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
24
+ cv2.imwrite(path, img)
25
+
26
+ # Document Segmentation Model
27
+ class U2NETP_DocSeg(nn.Module):
28
+ def __init__(self, num_classes=1):
29
+ super(U2NETP_DocSeg, self).__init__()
30
+ self.u2netp = U2NETP(out_ch=num_classes)
31
+
32
+ def forward(self, x):
33
+ mask, *_ = self.u2netp(x)
34
+ return mask
35
+
36
+ # Initialize the document segmentation model
37
+ docseg = U2NETP_DocSeg(num_classes=1).to(DEVICE)
38
+ # Load pretrained weights
39
+ docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
40
+ checkpoint = torch.load(docseg_weight_path)
41
+ docseg.load_state_dict(checkpoint[f"model_state_dict"])
42
+ docseg.eval()
43
+
44
+ # Get segmentation mask
45
+ def get_mask(image, confidence=0.5):
46
+ org_shape = image.shape[:2]
47
+ image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(DEVICE)
48
+ image_tensor = F.interpolate(image_tensor, size=(288, 288), mode='bilinear')
49
+ with torch.inference_mode(): # faster than no_grad
50
+ mask = docseg(image_tensor)
51
+ mask = (mask > confidence).float()
52
+ mask = F.interpolate(mask, size=org_shape, mode='bilinear')
53
+ return mask[0, 0] # keep tensor
54
+
55
+ def overlay_mask(image, mask):
56
+ image = torch.from_numpy(image).float().to(DEVICE)
57
+ red = torch.tensor([1.0, 0, 0], device=DEVICE).view(1, 3, 1, 1)
58
+ mask = mask.unsqueeze(0) # (1, H, W)
59
+ mask = mask.unsqueeze(0) # (1, 1, H, W)
60
+ overlay = image.permute(2, 0, 1).unsqueeze(0)
61
+ overlay = torch.where(mask > 0, red, overlay)
62
+ blended = 0.7 * image.permute(2, 0, 1).unsqueeze(0) + 0.3 * overlay
63
+ return blended[0].permute(1, 2, 0).cpu().numpy()
64
+
65
+ def segment_image(image):
66
+ """ Gradio function to segment input image and return overlay. """
67
+ image = image.astype(np.float32) / 255.0 # Normalize to [0, 1]
68
+ mask = get_mask(image, confidence=0.5)
69
+ overlayed_image = overlay_mask(image, mask)
70
+ yield overlayed_image
71
+
72
+ with gr.Blocks() as demo:
73
+ gr.Markdown("## Real-time Document Segmentation")
74
+ with gr.Row():
75
+ input_image = gr.Image(label="Input Image", type="numpy")
76
+ output_image = gr.Image(label="Segmentation Overlay", type="numpy")
77
+
78
+ input_image.change(segment_image, inputs=input_image, outputs=output_image)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ matplotlib
4
+ opencv-python
5
+ gradio
6
+ torchvision
seg.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+
8
+ class sobel_net(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
12
+ self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
13
+ sobel_kernelx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype='float32').reshape((1, 1, 3, 3))
14
+ sobel_kernely = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype='float32').reshape((1, 1, 3, 3))
15
+ self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
16
+ self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
17
+
18
+ for p in self.parameters():
19
+ p.requires_grad = False
20
+
21
+ def forward(self, im): # input rgb
22
+ x = (0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]).unsqueeze(1) # rgb2gray
23
+ gradx = self.conv_opx(x)
24
+ grady = self.conv_opy(x)
25
+
26
+ x = (gradx ** 2 + grady ** 2) ** 0.5
27
+ x = (x - x.min()) / (x.max() - x.min())
28
+ x = F.pad(x, (1, 1, 1, 1))
29
+
30
+ x = torch.cat([im, x], dim=1)
31
+ return x
32
+
33
+
34
+ class REBNCONV(nn.Module):
35
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
36
+ super(REBNCONV, self).__init__()
37
+
38
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
39
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
40
+ self.relu_s1 = nn.ReLU(inplace=True)
41
+
42
+ def forward(self, x):
43
+ hx = x
44
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
45
+
46
+ return xout
47
+
48
+
49
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
50
+ def _upsample_like(src, tar):
51
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
52
+ return src
53
+
54
+
55
+ ### RSU-7 ###
56
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block with 7 levels (U^2-Net building block).

    A small encoder-decoder: 6 downsampling stages plus a dilated bottom
    conv, with skip concatenations on the way back up. The decoder output
    is added to the input projection (residual connection), so the result
    has ``out_ch`` channels at the input's spatial resolution.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # Input projection to out_ch; also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv then 2x downsample, five times.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv in place of further downsampling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes [upsampled deeper feature, skip].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        # Encoder path (keep hx1..hx6 as skip connections).
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path: concat with skip, conv, upsample to next skip's size.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
130
+
131
+
132
+ ### RSU-6 ###
133
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block with 6 levels; same structure as RSU7 with one
    fewer encoder/decoder stage."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection / residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: four conv+pool stages.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages consume [upsampled deeper feature, skip].
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder path with skip concatenations.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
198
+
199
+
200
+ ### RSU-5 ###
201
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block with 5 levels; same structure as RSU7/RSU6 with
    fewer encoder/decoder stages."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection / residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: three conv+pool stages.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages consume [upsampled deeper feature, skip].
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder path with skip concatenations.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
256
+
257
+
258
+ ### RSU-4 ###
259
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block with 4 levels; same structure as the deeper RSUs
    with only two pooling stages."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection / residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: two conv+pool stages.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages consume [deeper feature, skip].
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder path with skip concatenations.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
304
+
305
+
306
+ ### RSU-4F ###
307
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Dilated ("F"ull-resolution) residual block: no pooling/upsampling.

    Receptive field grows via increasing dilation rates (1, 2, 4, 8)
    instead of downsampling, so every feature map keeps the input's
    spatial size.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection / residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # "Encoder": increasing dilation in place of pooling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # "Decoder": mirrored dilation rates, consuming skip concatenations.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # No upsampling needed: all features share the same spatial size.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Residual connection.
        return hx1d + hxin
340
+
341
+
342
+ ##### U^2-Net ####
343
class U2NET(nn.Module):
    """Full-size U^2-Net with a Sobel edge-channel front end.

    Returns 7 sigmoid maps: the fused prediction d0 followed by the six
    side outputs d1..d6, all at the resolution of d1.

    NOTE(review): self.edge appends one extra channel to the input, but
    stage1 is built with ``in_ch`` directly — callers apparently must pass
    in_ch = image channels + 1 for the shapes to line up; confirm.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()
        self.edge = sobel_net()

        # Encoder: RSU depth decreases as resolution decreases.
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads, one per decoder stage (plus the bottom stage).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # Fuses the 6 side outputs into d0.
        # NOTE(review): hard-coded 6 input channels only matches out_ch == 1
        # (6 side outputs of out_ch channels each) — confirm.
        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):
        # Append the Sobel edge channel before encoding.
        x = self.edge(x)
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: every map upsampled to d1's resolution.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        # Fused output from the concatenated side maps.
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(
            d4), torch.sigmoid(d5), torch.sigmoid(d6)
446
+
447
+ ### U^2-Net small ###
448
class U2NETP(nn.Module):
    """Small ("P"ortable) U^2-Net: every stage uses 64-channel features
    with mid_ch=16, and there is no edge front end.

    Returns 7 sigmoid maps: the fused prediction d0 followed by the six
    side outputs d1..d6, all at the resolution of d1.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # Encoder: RSU depth decreases as resolution decreases.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads, one per decoder stage (plus the bottom stage).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # Fuses the 6 side outputs into d0.
        # NOTE(review): hard-coded 6 input channels only matches out_ch == 1
        # (compare U2NETP_v2, which uses out_ch * 6) — confirm.
        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: every map upsampled to d1's resolution.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        # Fused output from the concatenated side maps.
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(
            d4), torch.sigmoid(d5), torch.sigmoid(d6)
550
+
551
class ClassifierHead(nn.Module):
    """MLP classification head over a 2D feature map.

    Collapses the spatial dimensions via the chosen *mode* ('avg_pool',
    'max_pool', or 'flatten'), then applies a stack of ReLU-activated
    linear layers followed by a single-logit output layer.

    Raises:
        ValueError: if *mode* is not one of the supported modes.
    """

    def __init__(self, in_channels=64, channels=[512, 128], mode='avg_pool'):
        super(ClassifierHead, self).__init__()
        self.linears = nn.ModuleList()
        prev = in_channels
        for width in channels:
            self.linears.append(nn.Linear(prev, width))
            prev = width
        self.cls = nn.Linear(channels[-1], 1)
        self.available_modes = ['avg_pool', 'max_pool', 'flatten']
        if mode not in self.available_modes:
            raise ValueError("Mode must be one of: {}".format(self.available_modes))
        self.mode = mode

    def forward(self, x):
        if self.mode == 'avg_pool':
            x = F.adaptive_avg_pool2d(x, (1, 1))
        elif self.mode == 'max_pool':
            x = F.adaptive_max_pool2d(x, (1, 1))
        elif self.mode == 'flatten':
            x = torch.flatten(x, 1)
        else:
            raise ValueError("Unsupported mode: {}".format(self.mode))
        x = x.view(x.size(0), -1)  # collapse to (batch, features)
        for linear in self.linears:
            x = F.relu(linear(x))
        return self.cls(x)
581
+
582
class U2NETP_v2(nn.Module):
    """Variant of U2NETP that returns raw logits instead of sigmoid maps.

    Differences from U2NETP: outconv takes out_ch * 6 input channels (so
    out_ch > 1 works), and forward returns ``(d0, hx6)`` — the fused
    logit map and the bottom encoder feature (e.g. for an auxiliary head)
    — with no sigmoid applied.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP_v2, self).__init__()

        # Encoder: RSU depth decreases as resolution decreases.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads, one per decoder stage (plus the bottom stage).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # Fuses the 6 side outputs; out_ch * 6 correctly generalizes the
        # hard-coded 6 used by U2NET/U2NETP.
        self.outconv = nn.Conv2d(out_ch * 6, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: every map upsampled to d1's resolution.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        # print(d1.shape, d2.shape, d3.shape, d4.shape, d5.shape, d6.shape)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # Raw logits; callers apply sigmoid / loss themselves.
        return d0, hx6
weights/seg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb79fdec55a5ed435dc74d8112aa9285d8213bae475022f711c709744fb19dd4
3
+ size 4715923
weights/u2netp_docseg_epoch_225_date_2026-01-02.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b39f3eb35f0985dc168eea21d9007a8467a79a3f80baa668f2b9ff6112f31ef6
3
+ size 14344719