ZhengPeng7 commited on
Commit
9f25bda
·
1 Parent(s): acdc9ae

For users to load in one key.

Browse files
Files changed (1) hide show
  1. birefnet.py +0 -287
birefnet.py DELETED
@@ -1,287 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from kornia.filters import laplacian
5
- from huggingface_hub import PyTorchModelHubMixin
6
-
7
- from config import Config
8
- from dataset import class_labels_TR_sorted
9
- from models.backbones.build_backbone import build_backbone
10
- from models.modules.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk
11
- from models.modules.lateral_blocks import BasicLatBlk
12
- from models.modules.aspp import ASPP, ASPPDeformable
13
- from models.modules.ing import *
14
- from models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet
15
- from models.refinement.stem_layer import StemLayer
16
-
17
-
18
- class BiRefNet(
19
- nn.Module,
20
- PyTorchModelHubMixin,
21
- library_name="birefnet",
22
- repo_url="https://github.com/ZhengPeng7/BiRefNet",
23
- tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
24
- ):
25
- def __init__(self, bb_pretrained=True):
26
- super(BiRefNet, self).__init__()
27
- self.config = Config()
28
- self.epoch = 1
29
- self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
30
-
31
- channels = self.config.lateral_channels_in_collection
32
-
33
- if self.config.auxiliary_classification:
34
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
35
- self.cls_head = nn.Sequential(
36
- nn.Linear(channels[0], len(class_labels_TR_sorted))
37
- )
38
-
39
- if self.config.squeeze_block:
40
- self.squeeze_module = nn.Sequential(*[
41
- eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
42
- for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
43
- ])
44
-
45
- self.decoder = Decoder(channels)
46
-
47
- if self.config.ender:
48
- self.dec_end = nn.Sequential(
49
- nn.Conv2d(1, 16, 3, 1, 1),
50
- nn.Conv2d(16, 1, 3, 1, 1),
51
- nn.ReLU(inplace=True),
52
- )
53
-
54
- # refine patch-level segmentation
55
- if self.config.refine:
56
- if self.config.refine == 'itself':
57
- self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
58
- else:
59
- self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
60
-
61
- if self.config.freeze_bb:
62
- # Freeze the backbone...
63
- print(self.named_parameters())
64
- for key, value in self.named_parameters():
65
- if 'bb.' in key and 'refiner.' not in key:
66
- value.requires_grad = False
67
-
68
- def forward_enc(self, x):
69
- if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
70
- x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
71
- else:
72
- x1, x2, x3, x4 = self.bb(x)
73
- if self.config.mul_scl_ipt == 'cat':
74
- B, C, H, W = x.shape
75
- x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
76
- x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
77
- x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
78
- x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
79
- x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
80
- elif self.config.mul_scl_ipt == 'add':
81
- B, C, H, W = x.shape
82
- x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
83
- x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
84
- x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
85
- x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
86
- x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
87
- class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
88
- if self.config.cxt:
89
- x4 = torch.cat(
90
- (
91
- *[
92
- F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
93
- F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
94
- F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
95
- ][-len(self.config.cxt):],
96
- x4
97
- ),
98
- dim=1
99
- )
100
- return (x1, x2, x3, x4), class_preds
101
-
102
- def forward_ori(self, x):
103
- ########## Encoder ##########
104
- (x1, x2, x3, x4), class_preds = self.forward_enc(x)
105
- if self.config.squeeze_block:
106
- x4 = self.squeeze_module(x4)
107
- ########## Decoder ##########
108
- features = [x, x1, x2, x3, x4]
109
- if self.training and self.config.out_ref:
110
- features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
111
- scaled_preds = self.decoder(features)
112
- return scaled_preds, class_preds
113
-
114
- def forward(self, x):
115
- scaled_preds, class_preds = self.forward_ori(x)
116
- class_preds_lst = [class_preds]
117
- return [scaled_preds, class_preds_lst] if self.training and 0 else scaled_preds
118
-
119
-
120
- class Decoder(nn.Module):
121
- def __init__(self, channels):
122
- super(Decoder, self).__init__()
123
- self.config = Config()
124
- DecoderBlock = eval(self.config.dec_blk)
125
- LateralBlock = eval(self.config.lat_blk)
126
-
127
- if self.config.dec_ipt:
128
- self.split = self.config.dec_ipt_split
129
- N_dec_ipt = 64
130
- DBlock = SimpleConvs
131
- ic = 64
132
- ipt_cha_opt = 1
133
- self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
134
- self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
135
- self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
136
- self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
137
- self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
138
- else:
139
- self.split = None
140
-
141
- self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1])
142
- self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2])
143
- self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3])
144
- self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2)
145
- self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0))
146
-
147
- self.lateral_block4 = LateralBlock(channels[1], channels[1])
148
- self.lateral_block3 = LateralBlock(channels[2], channels[2])
149
- self.lateral_block2 = LateralBlock(channels[3], channels[3])
150
-
151
- if self.config.ms_supervision:
152
- self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
153
- self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
154
- self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
155
-
156
- if self.config.out_ref:
157
- _N = 16
158
- self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
159
- self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
160
- self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
161
-
162
- self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
163
- self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
164
- self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
165
-
166
- self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
167
- self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
168
- self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
169
-
170
- def get_patches_batch(self, x, p):
171
- _size_h, _size_w = p.shape[2:]
172
- patches_batch = []
173
- for idx in range(x.shape[0]):
174
- columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
175
- patches_x = []
176
- for column_x in columns_x:
177
- patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
178
- patch_sample = torch.cat(patches_x, dim=1)
179
- patches_batch.append(patch_sample)
180
- return torch.cat(patches_batch, dim=0)
181
-
182
- def forward(self, features):
183
- if self.training and self.config.out_ref:
184
- outs_gdt_pred = []
185
- outs_gdt_label = []
186
- x, x1, x2, x3, x4, gdt_gt = features
187
- else:
188
- x, x1, x2, x3, x4 = features
189
- outs = []
190
-
191
- if self.config.dec_ipt:
192
- patches_batch = self.get_patches_batch(x, x4) if self.split else x
193
- x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
194
- p4 = self.decoder_block4(x4)
195
- m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
196
- if self.config.out_ref:
197
- p4_gdt = self.gdt_convs_4(p4)
198
- if self.training:
199
- # >> GT:
200
- m4_dia = m4
201
- gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
202
- outs_gdt_label.append(gdt_label_main_4)
203
- # >> Pred:
204
- gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
205
- outs_gdt_pred.append(gdt_pred_4)
206
- gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
207
- # >> Finally:
208
- p4 = p4 * gdt_attn_4
209
- _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
210
- _p3 = _p4 + self.lateral_block4(x3)
211
-
212
- if self.config.dec_ipt:
213
- patches_batch = self.get_patches_batch(x, _p3) if self.split else x
214
- _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
215
- p3 = self.decoder_block3(_p3)
216
- m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
217
- if self.config.out_ref:
218
- p3_gdt = self.gdt_convs_3(p3)
219
- if self.training:
220
- # >> GT:
221
- # m3 --dilation--> m3_dia
222
- # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
223
- m3_dia = m3
224
- gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
225
- outs_gdt_label.append(gdt_label_main_3)
226
- # >> Pred:
227
- # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
228
- # F_3^G --sigmoid--> A_3^G
229
- gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
230
- outs_gdt_pred.append(gdt_pred_3)
231
- gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
232
- # >> Finally:
233
- # p3 = p3 * A_3^G
234
- p3 = p3 * gdt_attn_3
235
- _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
236
- _p2 = _p3 + self.lateral_block3(x2)
237
-
238
- if self.config.dec_ipt:
239
- patches_batch = self.get_patches_batch(x, _p2) if self.split else x
240
- _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
241
- p2 = self.decoder_block2(_p2)
242
- m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
243
- if self.config.out_ref:
244
- p2_gdt = self.gdt_convs_2(p2)
245
- if self.training:
246
- # >> GT:
247
- m2_dia = m2
248
- gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
249
- outs_gdt_label.append(gdt_label_main_2)
250
- # >> Pred:
251
- gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
252
- outs_gdt_pred.append(gdt_pred_2)
253
- gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
254
- # >> Finally:
255
- p2 = p2 * gdt_attn_2
256
- _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
257
- _p1 = _p2 + self.lateral_block2(x1)
258
-
259
- if self.config.dec_ipt:
260
- patches_batch = self.get_patches_batch(x, _p1) if self.split else x
261
- _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
262
- _p1 = self.decoder_block1(_p1)
263
- _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
264
-
265
- if self.config.dec_ipt:
266
- patches_batch = self.get_patches_batch(x, _p1) if self.split else x
267
- _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
268
- p1_out = self.conv_out1(_p1)
269
-
270
- if self.config.ms_supervision:
271
- outs.append(m4)
272
- outs.append(m3)
273
- outs.append(m2)
274
- outs.append(p1_out)
275
- return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs)
276
-
277
-
278
- class SimpleConvs(nn.Module):
279
- def __init__(
280
- self, in_channels: int, out_channels: int, inter_channels=64
281
- ) -> None:
282
- super().__init__()
283
- self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
284
- self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
285
-
286
- def forward(self, x):
287
- return self.conv_out(self.conv1(x))