InPeerReview commited on
Commit
c47b9ca
·
verified ·
1 Parent(s): 2317bc0

Delete rscd/models/decoderheads/xformer.py

Browse files
Files changed (1) hide show
  1. rscd/models/decoderheads/xformer.py +0 -598
rscd/models/decoderheads/xformer.py DELETED
@@ -1,598 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.cuda.amp import autocast
4
- from rscd.models.decoderheads.vision_lstm import ViLBlock, SequenceTraversal
5
- from torch.nn import functional as F
6
- from functools import partial
7
- from rscd.models.backbones.lib_mamba.vmambanew import SS2D
8
- import pywt
9
-
10
- class PA(nn.Module):
11
- def __init__(self, dim, norm_layer, act_layer):
12
- super().__init__()
13
- self.p_conv = nn.Sequential(
14
- nn.Conv2d(dim, dim*4, 1, bias=False),
15
- norm_layer(dim*4),
16
- act_layer(),
17
- nn.Conv2d(dim*4, dim, 1, bias=False)
18
- )
19
- self.gate_fn = nn.Sigmoid()
20
-
21
- def forward(self, x):
22
- att = self.p_conv(x)
23
- x = x * self.gate_fn(att)
24
-
25
- return x
26
-
27
- class Mish(nn.Module):
28
- def __init__(self):
29
- super().__init__()
30
-
31
- def forward(self, x):
32
- return x * torch.tanh(F.softplus(x))
33
-
34
-
35
- class _ScaleModule(nn.Module):
36
- def __init__(self, dims, init_scale=1.0):
37
- super().__init__()
38
- self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
39
-
40
- def forward(self, x):
41
- return torch.mul(self.weight, x)
42
-
43
- def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
44
- w = pywt.Wavelet(wave)
45
- dec_hi = torch.tensor(w.dec_hi[::-1], dtype=dtype)
46
- dec_lo = torch.tensor(w.dec_lo[::-1], dtype=dtype)
47
- dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
48
- dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
49
- dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
50
- dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
51
- dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
52
- rec_hi = torch.tensor(w.rec_hi[::-1], dtype=dtype).flip(dims=[0])
53
- rec_lo = torch.tensor(w.rec_lo[::-1], dtype=dtype).flip(dims=[0])
54
- rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
55
- rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
56
- rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
57
- rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
58
- rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
59
- return dec_filters, rec_filters
60
-
61
- def wavelet_transform(x, filters):
62
- b, c, h, w = x.shape
63
- pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
64
- x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
65
- x = x.reshape(b, c, 4, h // 2, w // 2)
66
- return x
67
-
68
- def inverse_wavelet_transform(x, filters):
69
- b, c, _, h_half, w_half = x.shape
70
- pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
71
- x = x.reshape(b, c * 4, h_half, w_half)
72
- x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
73
- return x
74
-
75
- class MBWTConv2d(nn.Module):
76
- def __init__(self, in_channels, kernel_size=5, wt_levels=1, wt_type='db1', ssm_ratio=1, forward_type="v05"):
77
- super().__init__()
78
- assert in_channels == in_channels
79
- self.wt_levels = wt_levels
80
- self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels)
81
- self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
82
- self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
83
- self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
84
- self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)
85
- self.global_atten = SS2D(d_model=in_channels, d_state=1, ssm_ratio=ssm_ratio, initialize="v2",
86
- forward_type=forward_type, channel_first=True, k_group=2)
87
- self.base_scale = _ScaleModule([1, in_channels, 1, 1])
88
- self.wavelet_convs = nn.ModuleList([
89
- nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', groups=in_channels * 4)
90
- for _ in range(wt_levels)
91
- ])
92
- self.wavelet_scale = nn.ModuleList([
93
- _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1)
94
- for _ in range(wt_levels)
95
- ])
96
-
97
- def forward(self, x):
98
- x_ll_in_levels, x_h_in_levels, shapes_in_levels = [], [], []
99
- curr_x_ll = x
100
- for i in range(self.wt_levels):
101
- curr_shape = curr_x_ll.shape
102
- shapes_in_levels.append(curr_shape)
103
- if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
104
- curr_x_ll = F.pad(curr_x_ll, (0, curr_shape[3] % 2, 0, curr_shape[2] % 2))
105
- curr_x = self.wt_function(curr_x_ll)
106
- curr_x_ll = curr_x[:, :, 0, :, :]
107
- shape_x = curr_x.shape
108
- curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
109
- curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)).reshape(shape_x)
110
- x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
111
- x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
112
- next_x_ll = 0
113
- for i in range(self.wt_levels - 1, -1, -1):
114
- curr_x_ll = x_ll_in_levels.pop() + next_x_ll
115
- curr_x = torch.cat([curr_x_ll.unsqueeze(2), x_h_in_levels.pop()], dim=2)
116
- next_x_ll = self.iwt_function(curr_x)
117
- next_x_ll = next_x_ll[:, :, :shapes_in_levels[i][2], :shapes_in_levels[i][3]]
118
- x_tag = next_x_ll
119
- x = self.base_scale(self.global_atten(x)) + x_tag
120
- return x
121
-
122
- class ChannelAttention(nn.Module):
123
- def __init__(self, in_planes, ratio=16):
124
- super(ChannelAttention, self).__init__()
125
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
126
- self.max_pool = nn.AdaptiveMaxPool2d(1)
127
-
128
- self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
129
- self.relu1 = nn.ReLU()
130
- self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
131
- self.sigmoid = nn.Sigmoid()
132
-
133
- def forward(self, x):
134
- avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
135
- max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
136
- out = avg_out + max_out
137
- return self.sigmoid(out)
138
-
139
- # 空间注意力模块
140
- class SpatialAttention(nn.Module):
141
- def __init__(self, kernel_size=7):
142
- super(SpatialAttention, self).__init__()
143
-
144
- assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
145
- padding = 3 if kernel_size == 7 else 1
146
-
147
- self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
148
- self.sigmoid = nn.Sigmoid()
149
-
150
- def forward(self, x):
151
- avg_out = torch.mean(x, dim=1, keepdim=True)
152
- max_out, _ = torch.max(x, dim=1, keepdim=True)
153
- x = torch.cat([avg_out, max_out], dim=1)
154
- x = self.conv1(x)
155
- return self.sigmoid(x)
156
-
157
- # CBAM 注意力模块
158
- class CBAM(nn.Module):
159
- def __init__(self, in_planes):
160
- super(CBAM, self).__init__()
161
- self.ca = ChannelAttention(in_planes)
162
- self.sa = SpatialAttention()
163
-
164
- def forward(self, x):
165
- x = self.ca(x) * x
166
- x = self.sa(x) * x
167
- return x
168
-
169
- class DynamicConv2d(nn.Module):
170
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False, num_experts=4):
171
- super(DynamicConv2d, self).__init__()
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.kernel_size = kernel_size
175
- self.stride = stride
176
- self.padding = padding
177
- self.groups = groups
178
- self.bias = bias
179
- self.num_experts = num_experts
180
-
181
- self.experts = nn.ModuleList([
182
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
183
- for _ in range(num_experts)
184
- ])
185
- self.gating = nn.Sequential(
186
- nn.AdaptiveAvgPool2d(1),
187
- nn.Conv2d(in_channels, num_experts, 1, bias=False),
188
- nn.Softmax(dim=1)
189
- )
190
-
191
- def forward(self, x):
192
- gates = self.gating(x)
193
- gates = gates.view(x.size(0), self.num_experts, 1, 1, 1)
194
- outputs = []
195
- for i, expert in enumerate(self.experts):
196
- outputs.append(expert(x).unsqueeze(1))
197
- outputs = torch.cat(outputs, dim=1)
198
- out = (gates * outputs).sum(dim=1)
199
- return out
200
-
201
- class DWConv2d_BN_ReLU(nn.Sequential):
202
- def __init__(self, in_channels, out_channels, kernel_size=3):
203
- super().__init__()
204
- self.add_module('dwconv3x3', DynamicConv2d(in_channels, in_channels, kernel_size=kernel_size,
205
- stride=1, padding=kernel_size // 2, groups=in_channels, bias=False))
206
- self.add_module('bn1', nn.BatchNorm2d(in_channels))
207
- self.add_module('relu', Mish())
208
- self.add_module('dwconv1x1', nn.Conv2d(in_channels, out_channels, kernel_size=1,
209
- stride=1, padding=0, groups=in_channels, bias=False))
210
- self.add_module('bn2', nn.BatchNorm2d(out_channels))
211
-
212
- class Conv2d_BN(nn.Sequential):
213
- def __init__(self, a, b, ks=1, stride=1, pad=0, groups=1):
214
- super().__init__()
215
- self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, groups=groups, bias=False))
216
- self.add_module('bn', nn.BatchNorm2d(b))
217
-
218
- class FFN(nn.Module):
219
- def __init__(self, ed, h):
220
- super().__init__()
221
- self.pw1 = Conv2d_BN(ed, h)
222
- self.act = Mish()
223
- self.pw2 = Conv2d_BN(h, ed)
224
-
225
- def forward(self, x):
226
- return self.pw2(self.act(self.pw1(x)))
227
-
228
- class StochasticDepth(nn.Module):
229
- def __init__(self, survival_prob=0.8):
230
- super().__init__()
231
- self.survival_prob = survival_prob
232
-
233
- def forward(self, x):
234
- if not self.training:
235
- return x
236
- batch_size = x.shape[0]
237
- random_tensor = self.survival_prob + torch.rand([batch_size, 1, 1, 1], dtype=x.dtype, device=x.device)
238
- binary_tensor = torch.floor(random_tensor)
239
- return x * binary_tensor / self.survival_prob
240
-
241
- class Residual(nn.Module):
242
- def __init__(self, m, survival_prob=0.8):
243
- super().__init__()
244
- self.m = m
245
- self.stochastic_depth = StochasticDepth(survival_prob)
246
-
247
- def forward(self, x):
248
- return x + self.stochastic_depth(self.m(x))
249
-
250
- class GLP_block(nn.Module):
251
- def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,pa_ratio = 0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
252
- super().__init__()
253
- self.dim = dim
254
- self.global_channels = int(global_ratio * dim)
255
- self.local_channels = int(local_ratio * dim)
256
- self.pa_channels = int(pa_ratio * dim)
257
- self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
258
- self.local_op = nn.ModuleList([
259
- DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
260
- for k in [3, 5, 7]
261
- ]) if self.local_channels > 0 else nn.Identity()
262
- self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
263
- ssm_ratio=ssm_ratio, forward_type=forward_type) \
264
- if self.global_channels > 0 else nn.Identity()
265
- self.cbam = CBAM(dim)
266
- self.proj = nn.Sequential(
267
- Mish(),
268
- Conv2d_BN(dim, dim),
269
- CBAM(dim)
270
- )
271
-
272
- self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU) \
273
- if self.pa_channels > 0 else nn.Identity()
274
-
275
- def forward(self, x):
276
- x1, x2, x3, x4 = torch.split(x, [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels], dim=1)
277
- if isinstance(self.local_op, nn.ModuleList):
278
- local_features = [op(x2) for op in self.local_op]
279
- local_features = torch.cat(local_features, dim=1)
280
- local_features = torch.mean(local_features, dim=1, keepdim=True)
281
- local_features = local_features.expand(-1, self.local_channels, -1, -1)
282
- else:
283
- local_features = self.local_op(x2)
284
- out = torch.cat([self.global_op(x1), local_features, x3, self.pa_op(x4)], dim=1)
285
- return self.proj(out)
286
-
287
-
288
-
289
-
290
-
291
- class SASF(nn.Module):
292
- def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,pa_ratio = 0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
293
- super().__init__()
294
- self.dim = dim
295
- self.global_channels = int(global_ratio * dim)
296
- self.local_channels = int(local_ratio * dim)
297
- self.pa_channels = int(pa_ratio * dim)
298
- self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
299
- self.local_op = nn.ModuleList([
300
- DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
301
- for k in [3, 5, 7]
302
- ]) if self.local_channels > 0 else nn.Identity()
303
- self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
304
- ssm_ratio=ssm_ratio, forward_type=forward_type) \
305
- if self.global_channels > 0 else nn.Identity()
306
- self.cbam = CBAM(dim)
307
- self.proj = nn.Sequential(
308
- Mish(),
309
- Conv2d_BN(dim, dim),
310
- CBAM(dim)
311
- )
312
-
313
- self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU) \
314
- if self.pa_channels > 0 else nn.Identity()
315
-
316
- def forward(self, x):
317
- x1, x2, x3, x4 = torch.split(x, [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels], dim=1)
318
- if isinstance(self.local_op, nn.ModuleList):
319
- local_features = [op(x2) for op in self.local_op]
320
- local_features = torch.cat(local_features, dim=1)
321
- local_features = torch.mean(local_features, dim=1, keepdim=True)
322
- local_features = local_features.expand(-1, self.local_channels, -1, -1)
323
- else:
324
- local_features = self.local_op(x2)
325
- out = torch.cat([self.global_op(x1), local_features, x3, self.pa_op(x4)], dim=1)
326
- return self.proj(out)
327
-
328
-
329
- class ViLLayer(nn.Module):
330
- def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
331
- super().__init__()
332
- self.dim = dim
333
- self.norm = nn.LayerNorm(dim)
334
- self.vil = ViLBlock(
335
- dim= self.dim,
336
- direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT
337
- )
338
-
339
- @autocast(enabled=False)
340
- def forward(self, x):
341
- if x.dtype == torch.float16:
342
- x = x.type(torch.float32)
343
- B, C = x.shape[:2]
344
- assert C == self.dim
345
- n_tokens = x.shape[2:].numel()
346
- img_dims = x.shape[2:]
347
- x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
348
- x_vil = self.vil(x_flat)
349
- out = x_vil.transpose(-1, -2).reshape(B, C, *img_dims)
350
-
351
- return out
352
-
353
- def dsconv_3x3(in_channel, out_channel):
354
- return nn.Sequential(
355
- nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
356
- nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
357
- nn.BatchNorm2d(out_channel),
358
- nn.ReLU(inplace=True)
359
- )
360
-
361
- def conv_1x1(in_channel, out_channel):
362
- return nn.Sequential(
363
- nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
364
- nn.BatchNorm2d(out_channel),
365
- nn.ReLU(inplace=True)
366
- )
367
-
368
- class SqueezeAxialPositionalEmbedding(nn.Module):
369
- def __init__(self, dim, shape):
370
- super().__init__()
371
-
372
- self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))
373
-
374
- def forward(self, x):
375
- B, C, N = x.shape
376
- x = x + F.interpolate(self.pos_embed, size=(N), mode='linear', align_corners=False)
377
-
378
- return x
379
-
380
- class SEBlock(nn.Module):
381
- def __init__(self, channels, r=16):
382
- super().__init__()
383
- self.fc = nn.Sequential(
384
- nn.AdaptiveAvgPool2d(1),
385
- nn.Conv2d(channels, channels//r, 1),
386
- nn.ReLU(inplace=True),
387
- nn.Conv2d(channels//r, channels, 1),
388
- nn.Sigmoid()
389
- )
390
- def forward(self, x):
391
- w = self.fc(x) # (B, C, 1, 1)
392
- return x * w
393
- class CTTF1(nn.Module):
394
- def __init__(self, in_channel, out_channel,global_ratio=0.2, local_ratio=0.2, pa_ratio = 0.2 ,kernels=5, ssm_ratio=2.0, forward_type="v052d"):
395
- super().__init__()
396
- self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
397
- self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
398
- self.catconv = dsconv_3x3(in_channel * 2, out_channel)
399
- self.convA = nn.Conv2d(in_channel, 1, 1)
400
- self.convB = nn.Conv2d(in_channel, 1, 1)
401
- self.sigmoid = nn.Sigmoid()
402
-
403
- self.mixer = Residual(GLP_block(in_channel, global_ratio, local_ratio,pa_ratio, kernels, ssm_ratio, forward_type))
404
- self.mixer2 = Residual(
405
- SASF(in_channel, global_ratio = 0, local_ratio = 0.1, pa_ratio = 0, kernels = 5, ssm_ratio = 1, forward_type = "v052d"))
406
-
407
- self.fuse = nn.Sequential(
408
- nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
409
- nn.ReLU(inplace=True)
410
- )
411
- self.cbam = CBAM(in_channel * 3)
412
-
413
- self.act = nn.SiLU()
414
- def forward(self, xA, xB):
415
- x_diffA = self.mixer(xA)
416
- x_diffB = self.mixer(xB)
417
-
418
- f1 = x_diffA
419
- f2 = x_diffB
420
- diff_signed = f1 - f2
421
- diff_abs = torch.abs(diff_signed)
422
- sum_feat = f1 + f2
423
-
424
- diff_signed = self.mixer2(diff_signed)
425
- diff_abs = self.mixer2(diff_abs)
426
- sum_feat = self.mixer2(sum_feat)
427
- # 将多路特征在通道维度拼接
428
- f_fuse = torch.cat([diff_signed, diff_abs, sum_feat], dim=1) # (B, 4C, H, W)
429
- # 再接一个 1x1 卷积降维或提炼信息
430
- f_fuse = self.cbam(f_fuse)
431
- x_diff = self.fuse(f_fuse)
432
-
433
- return x_diff
434
-
435
- class CTTF2(nn.Module):
436
- def __init__(self, in_channel, out_channel, global_ratio=0.25, local_ratio=0.25, pa_ratio=0, kernels=7,
437
- ssm_ratio=2.0, forward_type="v052d"):
438
- super().__init__()
439
- self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
440
- self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
441
- self.catconv = dsconv_3x3(in_channel * 2, out_channel)
442
- self.convA = nn.Conv2d(in_channel, 1, 1)
443
- self.convB = nn.Conv2d(in_channel, 1, 1)
444
- self.sigmoid = nn.Sigmoid()
445
-
446
-
447
- self.mixer = Residual(
448
- GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
449
- self.mixer2 = Residual(
450
- SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1,
451
- forward_type="v052d"))
452
-
453
- self.fuse = nn.Sequential(
454
- nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
455
- nn.ReLU(inplace=True)
456
- )
457
- self.cbam = CBAM(in_channel * 3)
458
-
459
- self.act = nn.SiLU()
460
-
461
- def forward(self, xA, xB):
462
- x_diffA = self.mixer(xA)
463
- x_diffB = self.mixer(xB)
464
-
465
- f1 = x_diffA
466
- f2 = x_diffB
467
- diff_signed = f1 - f2
468
- diff_abs = torch.abs(diff_signed)
469
- sum_feat = f1 + f2
470
-
471
- diff_signed = self.mixer2(diff_signed)
472
- diff_abs = self.mixer2(diff_abs)
473
- sum_feat = self.mixer2(sum_feat)
474
- f_fuse = torch.cat([diff_signed, diff_abs, sum_feat], dim=1) # (B, 4C, H, W)
475
- f_fuse = self.cbam(f_fuse)
476
- x_diff = self.fuse(f_fuse)
477
-
478
- return x_diff
479
-
480
- class Mlp(nn.Module):
481
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=True):
482
- super().__init__()
483
- out_features = out_features or in_features
484
- hidden_features = hidden_features or in_features
485
-
486
- Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
487
- self.fc1 = Linear(in_features, hidden_features)
488
- self.act = act_layer()
489
- self.fc2 = Linear(hidden_features, out_features)
490
- self.drop = nn.Dropout(drop)
491
-
492
- def forward(self, x):
493
- x = self.fc1(x)
494
- x = self.act(x)
495
- x = self.drop(x)
496
- x = self.fc2(x)
497
- x = self.drop(x)
498
- return x
499
-
500
- class LHBlock(nn.Module):
501
- def __init__(self, channels_l, channels_h):
502
- super().__init__()
503
- self.channels_l = channels_l
504
- self.channels_h = channels_h
505
- self.cross_size = 12
506
- self.cross_kv = nn.Sequential(
507
- nn.BatchNorm2d(channels_l),
508
- nn.AdaptiveMaxPool2d(output_size=(self.cross_size, self.cross_size)),
509
- nn.Conv2d(channels_l, 2 * channels_h, 1, 1, 0)
510
- )
511
-
512
- self.conv = conv_1x1(channels_l, channels_h)
513
- self.norm = nn.BatchNorm2d(channels_h)
514
-
515
- self.mlp_l = Mlp(in_features=channels_l, out_features=channels_l)
516
- self.mlp_h = Mlp(in_features=channels_h, out_features=channels_h)
517
-
518
- def _act_sn(self, x):
519
- _, _, H, W = x.shape
520
- inner_channel = self.cross_size * self.cross_size
521
- x = x.reshape([-1, inner_channel, H, W]) * (inner_channel**-0.5)
522
- x = F.softmax(x, dim=1)
523
- x = x.reshape([1, -1, H, W])
524
- return x
525
-
526
- def attn_h(self, x_h, cross_k, cross_v):
527
- B, _, H, W = x_h.shape
528
- x_h = self.norm(x_h)
529
- x_h = x_h.reshape([1, -1, H, W]) # n,c_in,h,w -> 1,n*c_in,h,w
530
- x_h = F.conv2d(x_h, cross_k, bias=None, stride=1, padding=0,
531
- groups=B) # 1,n*c_in,h,w -> 1,n*144,h,w (group=B)
532
- x_h = self._act_sn(x_h)
533
- x_h = F.conv2d(x_h, cross_v, bias=None, stride=1, padding=0,
534
- groups=B) # 1,n*144,h,w -> 1, n*c_in,h,w (group=B)
535
- x_h = x_h.reshape([-1, self.channels_h, H,
536
- W]) # 1, n*c_in,h,w -> n,c_in,h,w (c_in = c_out)
537
-
538
- return x_h
539
-
540
- def forward(self, x_l, x_h):
541
- x_l = x_l + self.mlp_l(x_l)
542
- x_l_conv = self.conv(x_l)
543
- x_h = x_h + F.interpolate(x_l_conv, size=x_h.shape[2:], mode='bilinear')
544
-
545
- cross_kv = self.cross_kv(x_l)
546
- cross_k, cross_v = cross_kv.split(self.channels_h, 1)
547
- cross_k = cross_k.permute(0, 2, 3, 1).reshape([-1, self.channels_h, 1, 1]) # n*144,channels_h,1,1
548
- cross_v = cross_v.reshape([-1, self.cross_size * self.cross_size, 1, 1]) # n*channels_h,144,1,1
549
-
550
- x_h = x_h + self.attn_h(x_h, cross_k, cross_v) # [4, 40, 128, 128]
551
- x_h = x_h + self.mlp_h(x_h)
552
-
553
- return x_h
554
-
555
-
556
- class CTTF(nn.Module):
557
- def __init__(self, channels=[40, 80, 192, 384]):
558
- super().__init__()
559
- self.channels = channels
560
- self.fusion0 = CTTF1(channels[0], channels[0])
561
- self.fusion1 = CTTF1(channels[1], channels[1])
562
- self.fusion2 = CTTF2(channels[2], channels[2])
563
- self.fusion3 = CTTF2(channels[3], channels[3])
564
-
565
- self.LHBlock1 = LHBlock(channels[1], channels[0])
566
- self.LHBlock2 = LHBlock(channels[2], channels[0])
567
- self.LHBlock3 = LHBlock(channels[3], channels[0])
568
-
569
- self.mlp1 = Mlp(in_features=channels[0], out_features=channels[0])
570
- self.mlp2 = Mlp(in_features=channels[0], out_features=2)
571
- self.dwc = dsconv_3x3(channels[0], channels[0])
572
-
573
- def forward(self, inputs):
574
- featuresA, featuresB = inputs
575
- # fA_0, fA_1, fA_2, fA_3 = featuresA
576
- # fB_0, fB_1, fB_2, fB_3 = featuresB
577
- x_diff_0 = self.fusion0(featuresA[0], featuresB[0]) # [4, 40, 128, 128]
578
- x_diff_1 = self.fusion1(featuresA[1], featuresB[1]) # [4, 80, 64, 64]
579
- # x_diff_2 = featuresA[2] - featuresB[2]
580
- # x_diff_3 = featuresA[3] - featuresB[3]
581
- x_diff_2 = self.fusion2(featuresA[2], featuresB[2]) # [4, 192, 32, 32]
582
- x_diff_3 = self.fusion3(featuresA[3], featuresB[3]) # [4, 384, 16, 16]
583
-
584
- x_h = x_diff_0
585
- x_h = self.LHBlock1(x_diff_1, x_h) # [4, 40, 128, 128]
586
- x_h = self.LHBlock2(x_diff_2, x_h)
587
- x_h = self.LHBlock3(x_diff_3, x_h)
588
-
589
- out = self.mlp2(self.dwc(x_h) + self.mlp1(x_h))
590
-
591
- out = F.interpolate(
592
- out,
593
- scale_factor=(4, 4),
594
- mode="bilinear",
595
- align_corners=False,
596
- )
597
- return out
598
-