Spaces:
Sleeping
Sleeping
Update realesrgan/archs/rrdbnet_arch.py
Browse files- realesrgan/archs/rrdbnet_arch.py +61 -52
realesrgan/archs/rrdbnet_arch.py
CHANGED
|
@@ -1,52 +1,61 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
self.
|
| 12 |
-
self.
|
| 13 |
-
self.
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
self.
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ResidualDenseBlock(nn.Module):
|
| 7 |
+
"""Residual Dense Block used in RRDB"""
|
| 8 |
+
|
| 9 |
+
def __init__(self, nf=64, gc=32):
|
| 10 |
+
super(ResidualDenseBlock, self).__init__()
|
| 11 |
+
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
|
| 12 |
+
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
|
| 13 |
+
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
|
| 14 |
+
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
|
| 15 |
+
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
|
| 16 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
x1 = self.lrelu(self.conv1(x))
|
| 20 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
| 21 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
| 22 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
| 23 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
| 24 |
+
return x + 0.2 * x5
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class RRDB(nn.Module):
|
| 28 |
+
"""Residual in Residual Dense Block"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, nf, gc=32):
|
| 31 |
+
super(RRDB, self).__init__()
|
| 32 |
+
self.rdb1 = ResidualDenseBlock(nf, gc)
|
| 33 |
+
self.rdb2 = ResidualDenseBlock(nf, gc)
|
| 34 |
+
self.rdb3 = ResidualDenseBlock(nf, gc)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return x + 0.2 * self.rdb3(self.rdb2(self.rdb1(x)))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RRDBNet(nn.Module):
|
| 41 |
+
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
| 42 |
+
super(RRDBNet, self).__init__()
|
| 43 |
+
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
|
| 44 |
+
self.body = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
|
| 45 |
+
self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1)
|
| 46 |
+
|
| 47 |
+
# upsampling
|
| 48 |
+
self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1)
|
| 49 |
+
self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1)
|
| 50 |
+
self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1)
|
| 51 |
+
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
|
| 52 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
fea = self.conv_first(x)
|
| 56 |
+
trunk = self.conv_body(self.body(fea))
|
| 57 |
+
fea = fea + trunk
|
| 58 |
+
fea = self.lrelu(self.conv_up1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
| 59 |
+
fea = self.lrelu(self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
| 60 |
+
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
| 61 |
+
return out
|