therealcyberlord commited on
Commit
efa0b4b
·
1 Parent(s): 8698c52

Upload SRGAN.py

Browse files

model class for the ESRGAN model

Files changed (1) hide show
  1. SRGAN.py +79 -0
SRGAN.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class DenseResidualBlock(nn.Module):
5
+ """
6
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
7
+ """
8
+
9
+ def __init__(self, filters, res_scale=0.2):
10
+ super(DenseResidualBlock, self).__init__()
11
+ self.res_scale = res_scale
12
+
13
+ def block(in_features, non_linearity=True):
14
+ layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
15
+ if non_linearity:
16
+ layers += [nn.LeakyReLU()]
17
+ return nn.Sequential(*layers)
18
+
19
+ self.b1 = block(in_features=1 * filters)
20
+ self.b2 = block(in_features=2 * filters)
21
+ self.b3 = block(in_features=3 * filters)
22
+ self.b4 = block(in_features=4 * filters)
23
+ self.b5 = block(in_features=5 * filters, non_linearity=False)
24
+ self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]
25
+
26
+ def forward(self, x):
27
+ inputs = x
28
+ for block in self.blocks:
29
+ out = block(inputs)
30
+ inputs = torch.cat([inputs, out], 1)
31
+ return out.mul(self.res_scale) + x
32
+
33
+
34
+ class ResidualInResidualDenseBlock(nn.Module):
35
+ def __init__(self, filters, res_scale=0.2):
36
+ super(ResidualInResidualDenseBlock, self).__init__()
37
+ self.res_scale = res_scale
38
+ self.dense_blocks = nn.Sequential(
39
+ DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.dense_blocks(x).mul(self.res_scale) + x
44
+
45
+
46
+ class GeneratorRRDB(nn.Module):
47
+ def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
48
+ super(GeneratorRRDB, self).__init__()
49
+
50
+ # First layer
51
+ self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
52
+ # Residual blocks
53
+ self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
54
+ # Second conv layer post residual blocks
55
+ self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
56
+ # Upsampling layers
57
+ upsample_layers = []
58
+ for _ in range(num_upsample):
59
+ upsample_layers += [
60
+ nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
61
+ nn.LeakyReLU(),
62
+ nn.PixelShuffle(upscale_factor=2),
63
+ ]
64
+ self.upsampling = nn.Sequential(*upsample_layers)
65
+ # Final output block
66
+ self.conv3 = nn.Sequential(
67
+ nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
68
+ nn.LeakyReLU(),
69
+ nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
70
+ )
71
+
72
+ def forward(self, x):
73
+ out1 = self.conv1(x)
74
+ out = self.res_blocks(out1)
75
+ out2 = self.conv2(out)
76
+ out = torch.add(out1, out2)
77
+ out = self.upsampling(out)
78
+ out = self.conv3(out)
79
+ return out