Ayesha-Majeed committed on
Commit
910d021
·
verified ·
1 Parent(s): 6c7e42f

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +19 -38
binary_segmentation.py CHANGED
@@ -17,6 +17,7 @@ import torch
17
  from PIL import Image
18
  from torchvision import transforms
19
  import cv2
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
@@ -31,15 +32,12 @@ logger.info(f"Using device: {DEVICE}")
31
 
32
 
33
  class U2NETP(torch.nn.Module):
34
- """U2-Net Portrait (U2NETP) - Lightweight segmentation model
35
-
36
- Fixed to match pretrained weights architecture with 6 stages/side outputs.
37
- """
38
 
39
  def __init__(self, in_ch=3, out_ch=1):
40
  super(U2NETP, self).__init__()
41
 
42
- # Encoder (6 stages to match pretrained weights)
43
  self.stage1 = self._make_stage(in_ch, 16, 64)
44
  self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
45
 
@@ -50,31 +48,25 @@ class U2NETP(torch.nn.Module):
50
  self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
 
52
  self.stage4 = self._make_stage(64, 16, 64)
53
- self.pool45 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
-
55
- self.stage5 = self._make_stage(64, 16, 64)
56
- self.pool56 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
 
58
  # Bridge
59
- self.stage6 = self._make_stage(64, 16, 64)
60
 
61
- # Decoder (5 decoder stages)
62
- self.stage5d = self._make_stage(128, 16, 64)
63
  self.stage4d = self._make_stage(128, 16, 64)
64
  self.stage3d = self._make_stage(128, 16, 64)
65
  self.stage2d = self._make_stage(128, 16, 64)
66
  self.stage1d = self._make_stage(128, 16, 64)
67
 
68
- # Side outputs (6 outputs to match pretrained weights)
69
  self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
70
  self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
71
  self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
72
  self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
73
  self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
74
- self.side6 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
75
 
76
- # Output fusion (6 channels now)
77
- self.outconv = torch.nn.Conv2d(6 * out_ch, out_ch, 1)
78
 
79
  def _make_stage(self, in_ch, mid_ch, out_ch):
80
  return torch.nn.Sequential(
@@ -89,7 +81,7 @@ class U2NETP(torch.nn.Module):
89
  def forward(self, x):
90
  hx = x
91
 
92
- # Encoder with skip connections
93
  hx1 = self.stage1(hx)
94
  hx = self.pool12(hx1)
95
 
@@ -100,28 +92,18 @@ class U2NETP(torch.nn.Module):
100
  hx = self.pool34(hx3)
101
 
102
  hx4 = self.stage4(hx)
103
- hx = self.pool45(hx4)
104
-
105
- hx5 = self.stage5(hx)
106
- hx = self.pool56(hx5)
107
-
108
- # Bridge
109
- hx6 = self.stage6(hx)
110
-
111
- # Decoder with skip connections
112
- hx6up = torch.nn.functional.interpolate(hx6, size=hx5.shape[2:], mode='bilinear', align_corners=True)
113
- hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
114
 
115
- hx5dup = torch.nn.functional.interpolate(hx5d, size=hx4.shape[2:], mode='bilinear', align_corners=True)
116
- hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
 
117
 
118
- hx4dup = torch.nn.functional.interpolate(hx4d, size=hx3.shape[2:], mode='bilinear', align_corners=True)
119
  hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
 
120
 
121
- hx3dup = torch.nn.functional.interpolate(hx3d, size=hx2.shape[2:], mode='bilinear', align_corners=True)
122
  hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
 
123
 
124
- hx2dup = torch.nn.functional.interpolate(hx2d, size=hx1.shape[2:], mode='bilinear', align_corners=True)
125
  hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
126
 
127
  # Side outputs
@@ -129,13 +111,12 @@ class U2NETP(torch.nn.Module):
129
  d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
130
  d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
131
  d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
132
- d5 = torch.nn.functional.interpolate(self.side5(hx5d), size=d1.shape[2:], mode='bilinear', align_corners=True)
133
- d6 = torch.nn.functional.interpolate(self.side6(hx6), size=d1.shape[2:], mode='bilinear', align_corners=True)
134
 
135
- # Fusion (6 side outputs now)
136
- d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
137
 
138
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
139
 
140
 
141
  class BinarySegmenter:
 
17
  from PIL import Image
18
  from torchvision import transforms
19
  import cv2
20
+ from u2net import U2NETP
21
 
22
  # Configure logging
23
  logging.basicConfig(
 
32
 
33
 
34
  class U2NETP(torch.nn.Module):
35
+ """U2-Net Portrait (U2NETP) - Lightweight segmentation model"""
 
 
 
36
 
37
  def __init__(self, in_ch=3, out_ch=1):
38
  super(U2NETP, self).__init__()
39
 
40
+ # Encoder
41
  self.stage1 = self._make_stage(in_ch, 16, 64)
42
  self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
43
 
 
48
  self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
49
 
50
  self.stage4 = self._make_stage(64, 16, 64)
 
 
 
 
51
 
52
  # Bridge
53
+ self.stage5 = self._make_stage(64, 16, 64)
54
 
55
+ # Decoder
 
56
  self.stage4d = self._make_stage(128, 16, 64)
57
  self.stage3d = self._make_stage(128, 16, 64)
58
  self.stage2d = self._make_stage(128, 16, 64)
59
  self.stage1d = self._make_stage(128, 16, 64)
60
 
61
+ # Side outputs
62
  self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
63
  self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
64
  self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
65
  self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
66
  self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
 
67
 
68
+ # Output fusion
69
+ self.outconv = torch.nn.Conv2d(5 * out_ch, out_ch, 1)
70
 
71
  def _make_stage(self, in_ch, mid_ch, out_ch):
72
  return torch.nn.Sequential(
 
81
  def forward(self, x):
82
  hx = x
83
 
84
+ # Encoder
85
  hx1 = self.stage1(hx)
86
  hx = self.pool12(hx1)
87
 
 
92
  hx = self.pool34(hx3)
93
 
94
  hx4 = self.stage4(hx)
95
+ hx5 = self.stage5(hx4)
 
 
 
 
 
 
 
 
 
 
96
 
97
+ # Decoder
98
+ hx4d = self.stage4d(torch.cat((hx5, hx4), 1))
99
+ hx4dup = torch.nn.functional.interpolate(hx4d, scale_factor=2, mode='bilinear', align_corners=True)
100
 
 
101
  hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
102
+ hx3dup = torch.nn.functional.interpolate(hx3d, scale_factor=2, mode='bilinear', align_corners=True)
103
 
 
104
  hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
105
+ hx2dup = torch.nn.functional.interpolate(hx2d, scale_factor=2, mode='bilinear', align_corners=True)
106
 
 
107
  hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
108
 
109
  # Side outputs
 
111
  d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
112
  d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
113
  d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
114
+ d5 = torch.nn.functional.interpolate(self.side5(hx5), size=d1.shape[2:], mode='bilinear', align_corners=True)
 
115
 
116
+ # Fusion
117
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5), 1))
118
 
119
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5)
120
 
121
 
122
  class BinarySegmenter: