omaralaa2004 commited on
Commit
0b88c0c
·
verified ·
1 Parent(s): 9c81b98

Update realesrgan/archs/rrdbnet_arch.py

Browse files
Files changed (1) hide show
  1. realesrgan/archs/rrdbnet_arch.py +61 -52
realesrgan/archs/rrdbnet_arch.py CHANGED
@@ -1,52 +1,61 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class ResidualDenseBlock(nn.Module):
6
- def __init__(self, nf=64, gc=32):
7
- super().__init__()
8
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
9
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
10
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
11
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
12
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
13
- self.lrelu = nn.LeakyReLU(0.2, inplace=True)
14
-
15
- def forward(self, x):
16
- x1 = self.lrelu(self.conv1(x))
17
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
18
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
19
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
20
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
21
- return x + 0.2 * x5
22
-
23
- class RRDB(nn.Module):
24
- def __init__(self, nf, gc=32):
25
- super().__init__()
26
- self.rdb1 = ResidualDenseBlock(nf, gc)
27
- self.rdb2 = ResidualDenseBlock(nf, gc)
28
- self.rdb3 = ResidualDenseBlock(nf, gc)
29
-
30
- def forward(self, x):
31
- return x + 0.2 * self.rdb3(self.rdb2(self.rdb1(x)))
32
-
33
- class RRDBNet(nn.Module):
34
- def __init__(self, in_nc, out_nc, nf, nb, gc=32):
35
- super().__init__()
36
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
37
- self.RRDB_trunk = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
38
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)
39
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1)
40
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1)
41
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1)
42
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
43
- self.lrelu = nn.LeakyReLU(0.2, inplace=True)
44
-
45
- def forward(self, x):
46
- fea = self.conv_first(x)
47
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
48
- fea = fea + trunk
49
- fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
50
- fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
51
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
52
- return out
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ResidualDenseBlock(nn.Module):
7
+ """Residual Dense Block used in RRDB"""
8
+
9
+ def __init__(self, nf=64, gc=32):
10
+ super(ResidualDenseBlock, self).__init__()
11
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
12
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
13
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
14
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
15
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
16
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
+
18
+ def forward(self, x):
19
+ x1 = self.lrelu(self.conv1(x))
20
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
+ return x + 0.2 * x5
25
+
26
+
27
+ class RRDB(nn.Module):
28
+ """Residual in Residual Dense Block"""
29
+
30
+ def __init__(self, nf, gc=32):
31
+ super(RRDB, self).__init__()
32
+ self.rdb1 = ResidualDenseBlock(nf, gc)
33
+ self.rdb2 = ResidualDenseBlock(nf, gc)
34
+ self.rdb3 = ResidualDenseBlock(nf, gc)
35
+
36
+ def forward(self, x):
37
+ return x + 0.2 * self.rdb3(self.rdb2(self.rdb1(x)))
38
+
39
+
40
+ class RRDBNet(nn.Module):
41
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
42
+ super(RRDBNet, self).__init__()
43
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
44
+ self.body = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
45
+ self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1)
46
+
47
+ # upsampling
48
+ self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1)
49
+ self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1)
50
+ self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1)
51
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
52
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
53
+
54
+ def forward(self, x):
55
+ fea = self.conv_first(x)
56
+ trunk = self.conv_body(self.body(fea))
57
+ fea = fea + trunk
58
+ fea = self.lrelu(self.conv_up1(F.interpolate(fea, scale_factor=2, mode='nearest')))
59
+ fea = self.lrelu(self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')))
60
+ out = self.conv_last(self.lrelu(self.conv_hr(fea)))
61
+ return out