devbernie commited on
Commit
0ed552c
·
verified ·
1 Parent(s): 8c70566
Files changed (1) hide show
  1. app.py +50 -4
app.py CHANGED
@@ -18,11 +18,57 @@ MAX_IMAGE_SIZE = (1024, 1024)
18
  class RRDBNet(torch.nn.Module):
19
  def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
20
  super(RRDBNet, self).__init__()
21
- self.model = self._make_network(in_nc, out_nc, nf, nb, gc)
 
 
 
 
 
 
 
22
 
23
- def _make_network(self, in_nc, out_nc, nf, nb, gc):
24
- # [Original architecture implementation here...]
25
- # Full implementation: https://github.com/xinntao/ESRGAN/blob/master/RRDBNet_arch.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def load_model() -> torch.nn.Module:
28
  """Download and load ESRGAN model"""
 
18
  class RRDBNet(torch.nn.Module):
19
  def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
20
  super(RRDBNet, self).__init__()
21
+ self.conv_first = torch.nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
22
+ self.RRDB_trunk = torch.nn.ModuleList([RRDB(nf, gc=gc) for _ in range(nb)])
23
+ self.trunk_conv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
24
+ self.upconv1 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
25
+ self.upconv2 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
26
+ self.HRconv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
27
+ self.conv_last = torch.nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
28
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
29
 
30
+ def forward(self, x):
31
+ fea = self.conv_first(x)
32
+ trunk = fea.clone()
33
+ for block in self.RRDB_trunk:
34
+ trunk = block(trunk)
35
+ trunk = self.trunk_conv(trunk)
36
+ fea = fea + trunk
37
+ fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
38
+ fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
39
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
40
+ return out
41
+
42
+ class RRDB(torch.nn.Module):
43
+ def __init__(self, nf, gc=32):
44
+ super(RRDB, self).__init__()
45
+ self.RDB1 = ResidualDenseBlock(nf, gc)
46
+ self.RDB2 = ResidualDenseBlock(nf, gc)
47
+ self.RDB3 = ResidualDenseBlock(nf, gc)
48
+
49
+ def forward(self, x):
50
+ out = self.RDB1(x)
51
+ out = self.RDB2(out)
52
+ out = self.RDB3(out)
53
+ return out * 0.2 + x
54
+
55
+ class ResidualDenseBlock(torch.nn.Module):
56
+ def __init__(self, nf=64, gc=32, bias=True):
57
+ super(ResidualDenseBlock, self).__init__()
58
+ self.conv1 = torch.nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
59
+ self.conv2 = torch.nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
60
+ self.conv3 = torch.nn.Conv2d(nf + 2*gc, gc, 3, 1, 1, bias=bias)
61
+ self.conv4 = torch.nn.Conv2d(nf + 3*gc, gc, 3, 1, 1, bias=bias)
62
+ self.conv5 = torch.nn.Conv2d(nf + 4*gc, nf, 3, 1, 1, bias=bias)
63
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
64
+
65
+ def forward(self, x):
66
+ x1 = self.lrelu(self.conv1(x))
67
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
68
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
69
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
70
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
71
+ return x5 * 0.2 + x
72
 
73
  def load_model() -> torch.nn.Module:
74
  """Download and load ESRGAN model"""