Tracy7777 commited on
Commit
c65ea0a
·
verified ·
1 Parent(s): d2358fe

Update models/DehazeFormer.py

Browse files
Files changed (1) hide show
  1. models/DehazeFormer.py +114 -1
models/DehazeFormer.py CHANGED
@@ -471,4 +471,117 @@ class MCT(nn.Module):
471
  x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
472
  param = self.basenet(x_d)
473
  out = self.mapping(x, param)
474
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
472
  param = self.basenet(x_d)
473
  out = self.mapping(x, param)
474
+ return out
475
+ class DehazeFormer(nn.Module):
476
+ def __init__(self, in_chans=3, out_chans=4, window_size=8,
477
+ embed_dims=[24, 48, 96, 48, 24],
478
+ mlp_ratios=[2., 4., 4., 2., 2.],
479
+ depths=[16, 16, 16, 8, 8],
480
+ num_heads=[2, 4, 6, 1, 1],
481
+ attn_ratio=[1/4, 1/2, 3/4, 0, 0],
482
+ conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
483
+ norm_layer=[RLN, RLN, RLN, RLN, RLN]):
484
+ super(DehazeFormer, self).__init__()
485
+
486
+ # setting
487
+ self.patch_size = 4
488
+ self.window_size = window_size
489
+ self.mlp_ratios = mlp_ratios
490
+
491
+ # split image into non-overlapping patches
492
+ self.patch_embed = PatchEmbed(
493
+ patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)
494
+
495
+ # backbone
496
+ self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
497
+ num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
498
+ norm_layer=norm_layer[0], window_size=window_size,
499
+ attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])
500
+
501
+ self.patch_merge1 = PatchEmbed(
502
+ patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
503
+
504
+ self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
505
+
506
+ self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
507
+ num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
508
+ norm_layer=norm_layer[1], window_size=window_size,
509
+ attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])
510
+
511
+ self.patch_merge2 = PatchEmbed(
512
+ patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
513
+
514
+ self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
515
+
516
+ self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
517
+ num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
518
+ norm_layer=norm_layer[2], window_size=window_size,
519
+ attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])
520
+
521
+ self.patch_split1 = PatchUnEmbed(
522
+ patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])
523
+
524
+ assert embed_dims[1] == embed_dims[3]
525
+ self.fusion1 = SKFusion(embed_dims[3])
526
+
527
+ self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
528
+ num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
529
+ norm_layer=norm_layer[3], window_size=window_size,
530
+ attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])
531
+
532
+ self.patch_split2 = PatchUnEmbed(
533
+ patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])
534
+
535
+ assert embed_dims[0] == embed_dims[4]
536
+ self.fusion2 = SKFusion(embed_dims[4])
537
+
538
+ self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
539
+ num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
540
+ norm_layer=norm_layer[4], window_size=window_size,
541
+ attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])
542
+
543
+ # merge non-overlapping patches into image
544
+ self.patch_unembed = PatchUnEmbed(
545
+ patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)
546
+
547
+
548
+ def check_image_size(self, x):
549
+ # NOTE: for I2I test
550
+ _, _, h, w = x.size()
551
+ mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
552
+ mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
553
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
554
+ return x
555
+
556
+ def forward_features(self, x):
557
+ x = self.patch_embed(x)
558
+ x = self.layer1(x)
559
+ skip1 = x
560
+
561
+ x = self.patch_merge1(x)
562
+ x = self.layer2(x)
563
+ skip2 = x
564
+
565
+ x = self.patch_merge2(x)
566
+ x = self.layer3(x)
567
+ x = self.patch_split1(x)
568
+
569
+ x = self.fusion1([x, self.skip2(skip2)]) + x
570
+ x = self.layer4(x)
571
+ x = self.patch_split2(x)
572
+
573
+ x = self.fusion2([x, self.skip1(skip1)]) + x
574
+ x = self.layer5(x)
575
+ x = self.patch_unembed(x)
576
+ return x
577
+
578
+ def forward(self, x):
579
+ H, W = x.shape[2:]
580
+ x = self.check_image_size(x)
581
+
582
+ feat = self.forward_features(x)
583
+ K, B = torch.split(feat, (1, 3), dim=1)
584
+
585
+ x = K * x - B + x
586
+ x = x[:, :, :H, :W]
587
+ return x