qqwjq1981 commited on
Commit
d859cd0
·
verified ·
1 Parent(s): 0634acd

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/u2net_detector.py +182 -0
  2. utils/u2netp.pth +3 -0
utils/u2net_detector.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ from shapely.geometry import Polygon
8
+
9
+ # -------------------------------------------------------------------
10
+ # U²-Netp Model Definition (lightweight 4.7MB)
11
+ # -------------------------------------------------------------------
12
+ # Source: https://github.com/xuebinqin/U-2-Net
13
+ # We include only the necessary modules.
14
+
15
+ class REBNCONV(torch.nn.Module):
16
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
17
+ super(REBNCONV, self).__init__()
18
+ self.conv_s1 = torch.nn.Conv2d(
19
+ in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate
20
+ )
21
+ self.relu_s1 = torch.nn.ReLU(inplace=True)
22
+
23
+ def forward(self, x):
24
+ hx = x
25
+ hx = self.relu_s1(self.conv_s1(hx))
26
+ return hx
27
+
28
+
29
+ class RSU4F(torch.nn.Module):
30
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
31
+ super(RSU4F, self).__init__()
32
+
33
+ self.rebnconvin = REBNCONV(in_ch, out_ch)
34
+
35
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch)
36
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch)
37
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch)
38
+
39
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch)
40
+
41
+ def forward(self, x):
42
+ hx = x
43
+ hxin = self.rebnconvin(hx)
44
+
45
+ hx1 = self.rebnconv1(hxin)
46
+ hx2 = self.rebnconv2(hx1)
47
+ hx3 = self.rebnconv3(hx2)
48
+ hx4 = self.rebnconv4(hx3)
49
+
50
+ return hxin + hx4
51
+
52
+
53
+ class U2NETP(torch.nn.Module):
54
+ def __init__(self, in_ch=3, out_ch=1):
55
+ super(U2NETP, self).__init__()
56
+
57
+ self.stage1 = RSU4F(in_ch, 12, 64)
58
+ self.stage2 = RSU4F(64, 12, 64)
59
+ self.stage3 = RSU4F(64, 12, 64)
60
+ self.stage4 = RSU4F(64, 12, 64)
61
+ self.stage5 = RSU4F(64, 12, 64)
62
+ self.stage6 = RSU4F(64, 12, 64)
63
+ self.side6 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
64
+
65
+ def forward(self, x):
66
+ hx1 = self.stage1(x)
67
+ hx2 = self.stage2(hx1)
68
+ hx3 = self.stage3(hx2)
69
+
70
+ hx6 = self.stage6(hx3)
71
+ d6 = self.side6(hx6)
72
+
73
+ return torch.sigmoid(d6)
74
+
75
+
76
+ # -------------------------------------------------------------------
77
+ # Load Model (once)
78
+ # -------------------------------------------------------------------
79
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "u2netp.pth")
80
+
81
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
+ _u2net_model = None
83
+
84
+
85
+ def load_u2netp():
86
+ global _u2net_model
87
+ if _u2net_model is None:
88
+ print("🔄 Loading U²-Netp model…")
89
+ model = U2NETP()
90
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=_device))
91
+ model.to(_device)
92
+ model.eval()
93
+ _u2net_model = model
94
+ print("✅ U²-Netp Loaded.")
95
+ return _u2net_model
96
+
97
+
98
+ # -------------------------------------------------------------------
99
+ # Preprocessing
100
+ # -------------------------------------------------------------------
101
+ def preprocess(img_pil, size=320):
102
+ img = img_pil.convert("RGB")
103
+ img = img.resize((size, size), Image.BILINEAR)
104
+ arr = np.array(img).astype(np.float32) / 255.0
105
+
106
+ tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
107
+ return tensor.to(_device), img_pil.size
108
+
109
+
110
+ # -------------------------------------------------------------------
111
+ # Postprocessing → polygon conversion
112
+ # -------------------------------------------------------------------
113
+ def mask_to_polygons(mask, min_area=300):
114
+ """
115
+ Convert binary mask → list of polygons (list[list[(x,y)]])
116
+ """
117
+ mask = (mask * 255).astype("uint8")
118
+
119
+ # cleanup
120
+ kernel = np.ones((5,5), np.uint8)
121
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
122
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
123
+
124
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
125
+
126
+ polys = []
127
+
128
+ for cnt in contours:
129
+ area = cv2.contourArea(cnt)
130
+ if area < min_area:
131
+ continue
132
+
133
+ eps = 0.01 * cv2.arcLength(cnt, True)
134
+ approx = cv2.approxPolyDP(cnt, eps, True)
135
+
136
+ poly = [(int(p[0][0]), int(p[0][1])) for p in approx]
137
+ polys.append(poly)
138
+
139
+ return polys
140
+
141
+
142
+ def resize_polygons(polygons, orig_w, orig_h, proc_size=320):
143
+ """Scale polygons back to original image size"""
144
+ scaled = []
145
+ for poly in polygons:
146
+ scaled.append([
147
+ (
148
+ int(x * orig_w / proc_size),
149
+ int(y * orig_h / proc_size)
150
+ )
151
+ for (x, y) in poly
152
+ ])
153
+ return scaled
154
+
155
+
156
+ # -------------------------------------------------------------------
157
+ # Main Bubble Detection Function
158
+ # -------------------------------------------------------------------
159
+ def detect_bubbles_u2net(img_pil, min_area=300):
160
+ """
161
+ Return list of bubble polygons from U²-Net saliency segmentation.
162
+ """
163
+ model = load_u2netp()
164
+
165
+ tensor, orig_size = preprocess(img_pil)
166
+ orig_w, orig_h = img_pil.size
167
+
168
+ with torch.no_grad():
169
+ pred = model(tensor)[0, 0].cpu().numpy()
170
+
171
+ # Normalize & threshold
172
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
173
+ mask = (pred > 0.4).astype(np.uint8)
174
+
175
+ # polygons from mask
176
+ polys = mask_to_polygons(mask, min_area=min_area)
177
+
178
+ # rescale to original image size
179
+ polys = resize_polygons(polys, orig_w, orig_h)
180
+
181
+ print(f"🧠 U²-Net bubbles detected: {len(polys)}")
182
+ return polys
utils/u2netp.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7567cde013fb64813973ce6e1ecc25a80c05c3ca7adbc5a54f3c3d90991b854
3
+ size 4683258