Ayesha-Majeed committed on
Commit
6c7e42f
·
verified ·
1 Parent(s): b93281d

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +39 -20
binary_segmentation.py CHANGED
@@ -17,7 +17,6 @@ import torch
17
  from PIL import Image
18
  from torchvision import transforms
19
  import cv2
20
- from u2net import U2NETP
21
 
22
  # Configure logging
23
  logging.basicConfig(
@@ -32,12 +31,15 @@ logger.info(f"Using device: {DEVICE}")
32
 
33
 
34
  class U2NETP(torch.nn.Module):
35
- """U2-Net Portrait (U2NETP) - Lightweight segmentation model"""
 
 
 
36
 
37
  def __init__(self, in_ch=3, out_ch=1):
38
  super(U2NETP, self).__init__()
39
 
40
- # Encoder
41
  self.stage1 = self._make_stage(in_ch, 16, 64)
42
  self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
43
 
@@ -48,25 +50,31 @@ class U2NETP(torch.nn.Module):
48
  self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
49
 
50
  self.stage4 = self._make_stage(64, 16, 64)
 
51
 
52
- # Bridge
53
  self.stage5 = self._make_stage(64, 16, 64)
 
54
 
55
- # Decoder
 
 
 
 
56
  self.stage4d = self._make_stage(128, 16, 64)
57
  self.stage3d = self._make_stage(128, 16, 64)
58
  self.stage2d = self._make_stage(128, 16, 64)
59
  self.stage1d = self._make_stage(128, 16, 64)
60
 
61
- # Side outputs
62
  self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
63
  self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
64
  self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
65
  self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
66
  self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
 
67
 
68
- # Output fusion
69
- self.outconv = torch.nn.Conv2d(5 * out_ch, out_ch, 1)
70
 
71
  def _make_stage(self, in_ch, mid_ch, out_ch):
72
  return torch.nn.Sequential(
@@ -81,7 +89,7 @@ class U2NETP(torch.nn.Module):
81
  def forward(self, x):
82
  hx = x
83
 
84
- # Encoder
85
  hx1 = self.stage1(hx)
86
  hx = self.pool12(hx1)
87
 
@@ -92,18 +100,28 @@ class U2NETP(torch.nn.Module):
92
  hx = self.pool34(hx3)
93
 
94
  hx4 = self.stage4(hx)
95
- hx5 = self.stage5(hx4)
96
 
97
- # Decoder
98
- hx4d = self.stage4d(torch.cat((hx5, hx4), 1))
99
- hx4dup = torch.nn.functional.interpolate(hx4d, scale_factor=2, mode='bilinear', align_corners=True)
100
 
 
 
 
 
 
 
 
 
 
 
 
101
  hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
102
- hx3dup = torch.nn.functional.interpolate(hx3d, scale_factor=2, mode='bilinear', align_corners=True)
103
 
 
104
  hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
105
- hx2dup = torch.nn.functional.interpolate(hx2d, scale_factor=2, mode='bilinear', align_corners=True)
106
 
 
107
  hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
108
 
109
  # Side outputs
@@ -111,12 +129,13 @@ class U2NETP(torch.nn.Module):
111
  d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
112
  d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
113
  d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
114
- d5 = torch.nn.functional.interpolate(self.side5(hx5), size=d1.shape[2:], mode='bilinear', align_corners=True)
 
115
 
116
- # Fusion
117
- d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5), 1))
118
 
119
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5)
120
 
121
 
122
  class BinarySegmenter:
@@ -396,4 +415,4 @@ if __name__ == "__main__":
396
  model_type=args.model,
397
  threshold=args.threshold,
398
  save_rgba=(args.format == "rgba")
399
- )
 
17
  from PIL import Image
18
  from torchvision import transforms
19
  import cv2
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
 
31
 
32
 
33
  class U2NETP(torch.nn.Module):
34
+ """U2-Net Portrait (U2NETP) - Lightweight segmentation model
35
+
36
+ Fixed to match pretrained weights architecture with 6 stages/side outputs.
37
+ """
38
 
39
  def __init__(self, in_ch=3, out_ch=1):
40
  super(U2NETP, self).__init__()
41
 
42
+ # Encoder (6 stages to match pretrained weights)
43
  self.stage1 = self._make_stage(in_ch, 16, 64)
44
  self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
45
 
 
50
  self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
 
52
  self.stage4 = self._make_stage(64, 16, 64)
53
+ self.pool45 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
 
 
55
  self.stage5 = self._make_stage(64, 16, 64)
56
+ self.pool56 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
 
58
+ # Bridge
59
+ self.stage6 = self._make_stage(64, 16, 64)
60
+
61
+ # Decoder (5 decoder stages)
62
+ self.stage5d = self._make_stage(128, 16, 64)
63
  self.stage4d = self._make_stage(128, 16, 64)
64
  self.stage3d = self._make_stage(128, 16, 64)
65
  self.stage2d = self._make_stage(128, 16, 64)
66
  self.stage1d = self._make_stage(128, 16, 64)
67
 
68
+ # Side outputs (6 outputs to match pretrained weights)
69
  self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
70
  self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
71
  self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
72
  self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
73
  self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
74
+ self.side6 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
75
 
76
+ # Output fusion (6 channels now)
77
+ self.outconv = torch.nn.Conv2d(6 * out_ch, out_ch, 1)
78
 
79
  def _make_stage(self, in_ch, mid_ch, out_ch):
80
  return torch.nn.Sequential(
 
89
  def forward(self, x):
90
  hx = x
91
 
92
+ # Encoder with skip connections
93
  hx1 = self.stage1(hx)
94
  hx = self.pool12(hx1)
95
 
 
100
  hx = self.pool34(hx3)
101
 
102
  hx4 = self.stage4(hx)
103
+ hx = self.pool45(hx4)
104
 
105
+ hx5 = self.stage5(hx)
106
+ hx = self.pool56(hx5)
 
107
 
108
+ # Bridge
109
+ hx6 = self.stage6(hx)
110
+
111
+ # Decoder with skip connections
112
+ hx6up = torch.nn.functional.interpolate(hx6, size=hx5.shape[2:], mode='bilinear', align_corners=True)
113
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
114
+
115
+ hx5dup = torch.nn.functional.interpolate(hx5d, size=hx4.shape[2:], mode='bilinear', align_corners=True)
116
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
117
+
118
+ hx4dup = torch.nn.functional.interpolate(hx4d, size=hx3.shape[2:], mode='bilinear', align_corners=True)
119
  hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
 
120
 
121
+ hx3dup = torch.nn.functional.interpolate(hx3d, size=hx2.shape[2:], mode='bilinear', align_corners=True)
122
  hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
 
123
 
124
+ hx2dup = torch.nn.functional.interpolate(hx2d, size=hx1.shape[2:], mode='bilinear', align_corners=True)
125
  hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
126
 
127
  # Side outputs
 
129
  d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
130
  d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
131
  d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
132
+ d5 = torch.nn.functional.interpolate(self.side5(hx5d), size=d1.shape[2:], mode='bilinear', align_corners=True)
133
+ d6 = torch.nn.functional.interpolate(self.side6(hx6), size=d1.shape[2:], mode='bilinear', align_corners=True)
134
 
135
+ # Fusion (6 side outputs now)
136
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
137
 
138
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
139
 
140
 
141
  class BinarySegmenter:
 
415
  model_type=args.model,
416
  threshold=args.threshold,
417
  save_rgba=(args.format == "rgba")
418
+ )