Spaces:
Runtime error
Runtime error
Update models/DehazeFormer.py
Browse files- models/DehazeFormer.py +114 -1
models/DehazeFormer.py
CHANGED
|
@@ -471,4 +471,117 @@ class MCT(nn.Module):
|
|
| 471 |
x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
|
| 472 |
param = self.basenet(x_d)
|
| 473 |
out = self.mapping(x, param)
|
| 474 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
|
| 472 |
param = self.basenet(x_d)
|
| 473 |
out = self.mapping(x, param)
|
| 474 |
+
return out
|
| 475 |
+
class DehazeFormer(nn.Module):
    """Five-stage encoder/decoder network built from ``BasicLayer`` stages.

    Stages 1-3 are joined by ``PatchEmbed`` merges, stages 3-5 by
    ``PatchUnEmbed`` splits. The outputs of stages 1 and 2 are passed through
    1x1 convolutions and fused back in via ``SKFusion`` before stages 5 and 4
    respectively. ``forward`` splits the feature map into a 1-channel ``K``
    and a 3-channel ``B`` and returns ``K * x - B + x``.

    Args:
        in_chans: Channel count handed to the first ``PatchEmbed``.
        out_chans: Channel count produced by the last ``PatchUnEmbed``.
            NOTE(review): ``forward`` splits the result into (1, 3) channels,
            so this presumably has to stay 4 -- confirm before changing.
        window_size: Forwarded unchanged to every ``BasicLayer``.
        embed_dims: Per-stage feature width; entries 1/3 and 0/4 must match.
        mlp_ratios: Per-stage ``mlp_ratio`` for ``BasicLayer``.
        depths: Per-stage ``depth``; their sum is passed as ``network_depth``.
        num_heads: Per-stage ``num_heads`` for ``BasicLayer``.
        attn_ratio: Per-stage ``attn_ratio`` for ``BasicLayer``.
        conv_type: Per-stage ``conv_type`` for ``BasicLayer``.
        norm_layer: Per-stage ``norm_layer`` for ``BasicLayer``.
    """

    def __init__(self, in_chans=3, out_chans=4, window_size=8,
                 embed_dims=[24, 48, 96, 48, 24],
                 mlp_ratios=[2., 4., 4., 2., 2.],
                 depths=[16, 16, 16, 8, 8],
                 num_heads=[2, 4, 6, 1, 1],
                 attn_ratio=[1/4, 1/2, 3/4, 0, 0],
                 conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
                 norm_layer=[RLN, RLN, RLN, RLN, RLN]):
        super().__init__()

        # Settings kept on the instance.
        self.patch_size = 4
        self.window_size = window_size
        self.mlp_ratios = mlp_ratios

        # Every stage is told the depth of the whole network.
        total_depth = sum(depths)

        def build_stage(idx):
            # All five stages share one recipe; only the per-stage entries
            # of the configuration lists differ.
            return BasicLayer(
                network_depth=total_depth,
                dim=embed_dims[idx],
                depth=depths[idx],
                num_heads=num_heads[idx],
                mlp_ratio=mlp_ratios[idx],
                norm_layer=norm_layer[idx],
                window_size=window_size,
                attn_ratio=attn_ratio[idx],
                attn_loc='last',
                conv_type=conv_type[idx],
            )

        # NOTE: submodules are created in the same order as before so that
        # parameter registration (and seeded initialisation) is unchanged.

        # Input embedding: split image into non-overlapping patches.
        self.patch_embed = PatchEmbed(
            patch_size=1,
            in_chans=in_chans,
            embed_dim=embed_dims[0],
            kernel_size=3,
        )

        # Backbone, stage 1.
        self.layer1 = build_stage(0)
        self.patch_merge1 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)

        # Stage 2.
        self.layer2 = build_stage(1)
        self.patch_merge2 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)

        # Stage 3.
        self.layer3 = build_stage(2)
        self.patch_split1 = PatchUnEmbed(
            patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])

        # The stage-2 skip is fused with the stage-4 input, so widths must agree.
        assert embed_dims[1] == embed_dims[3]
        self.fusion1 = SKFusion(embed_dims[3])

        # Stage 4.
        self.layer4 = build_stage(3)
        self.patch_split2 = PatchUnEmbed(
            patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])

        # Same constraint for the stage-1 skip and the stage-5 input.
        assert embed_dims[0] == embed_dims[4]
        self.fusion2 = SKFusion(embed_dims[4])

        # Stage 5.
        self.layer5 = build_stage(4)

        # Output projection: merge non-overlapping patches into image.
        self.patch_unembed = PatchUnEmbed(
            patch_size=1,
            out_chans=out_chans,
            embed_dim=embed_dims[4],
            kernel_size=3,
        )

    def check_image_size(self, x):
        """Reflect-pad ``x`` so height and width are multiples of ``patch_size``.

        Padding is added only on the right and bottom edges; ``x`` must be
        4-dimensional. (Original note: used for the I2I test.)
        """
        _batch, _chans, height, width = x.size()
        step = self.patch_size
        # ``-n % step`` is the amount needed to reach the next multiple of step
        # (0 when n is already a multiple).
        extra_h = -height % step
        extra_w = -width % step
        return F.pad(x, (0, extra_w, 0, extra_h), 'reflect')

    def forward_features(self, x):
        """Run the encoder/decoder and return the ``patch_unembed`` output."""
        # Encoder.
        enc1 = self.layer1(self.patch_embed(x))
        enc2 = self.layer2(self.patch_merge1(enc1))
        deep = self.layer3(self.patch_merge2(enc2))

        # Decoder: fuse each upsampled map with its projected skip, keeping
        # a residual connection around the fusion.
        up = self.patch_split1(deep)
        up = self.fusion1([up, self.skip2(enc2)]) + up
        up = self.layer4(up)

        up = self.patch_split2(up)
        up = self.fusion2([up, self.skip1(enc1)]) + up
        up = self.layer5(up)

        return self.patch_unembed(up)

    def forward(self, x):
        """Return ``K * x - B + x`` cropped back to the input's height/width."""
        orig_h, orig_w = x.shape[2:]
        padded = self.check_image_size(x)

        feat = self.forward_features(padded)
        # Channel 0 is the multiplicative term K, channels 1-3 the offset B.
        K, B = torch.split(feat, (1, 3), dim=1)

        restored = K * padded - B + padded
        # Drop the rows/columns that were added by check_image_size.
        return restored[:, :, :orig_h, :orig_w]
|