thecr7guy commited on
Commit
9139b81
·
verified ·
1 Parent(s): 6ab8f64

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -20
model.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch
2
  import torch.nn as nn
3
- # from torchinfo import summary
4
 
5
 
6
  class RD_block(nn.Module):
@@ -61,31 +60,38 @@ class UpsampleBlock(nn.Module):
61
  return self.act(self.conv(self.upsample(x)))
62
 
63
 
64
- class RRDBNet(nn.Module):
65
  def __init__(self, in_channels, out_channels, channels, growth_channels, upscale_factor, residual_beta):
66
- super(RRDBNet, self).__init__()
67
  self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3,
68
  stride=1, padding=1)
69
- self.res_block = nn.Sequential(*[RRD_block(channels, growth_channels, residual_beta) for _ in range(23)])
70
-
 
 
 
71
  self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
72
  stride=1, padding=1)
73
-
74
  self.upsample = nn.Sequential(
75
  UpsampleBlock(channels, upscale_factor), UpsampleBlock(channels, upscale_factor),
76
  )
77
 
78
  self.conv3 = nn.Sequential(
79
  nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
80
- nn.LeakyReLU(0.2, True)
81
  )
82
 
83
  self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))
84
 
85
  def forward(self, x):
86
- temp = x
87
  out1 = self.conv1(x)
88
- out2 = self.conv2(self.res_block(out1))
 
 
 
 
 
 
89
  out3 = torch.add(out2, out1)
90
  out4 = self.upsample(out3)
91
  out5 = self.conv3(out4)
@@ -199,6 +205,8 @@ class Discriminator(nn.Module):
199
  out = self.classifier(out)
200
 
201
  return out
 
 
202
  #############################################
203
  def weights_init(m):
204
  if isinstance(m, nn.Conv2d):
@@ -206,14 +214,3 @@ def weights_init(m):
206
  m.weight.data *= 0.1
207
  if m.bias is not None:
208
  nn.init.constant_(m.bias, 0)
209
-
210
- #
211
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
212
- # gen = RRDBNet(3, 3, 64, 32, 2, 0.2).to(device)
213
- # # # gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.9, 0.999))
214
- # # # gen_model = gen.apply(weights_init)
215
- # summary(gen, input_size=(16, 3, 64, 64))
216
- # # # dis = Discriminator().to(device)
217
- # # # summary(dis, input_size=(16, 3, 256, 256))
218
-
219
- #############################################
 
1
  import torch
2
  import torch.nn as nn
 
3
 
4
 
5
  class RD_block(nn.Module):
 
60
  return self.act(self.conv(self.upsample(x)))
61
 
62
 
63
+ class DRRRDBNet(nn.Module):
64
  def __init__(self, in_channels, out_channels, channels, growth_channels, upscale_factor, residual_beta):
65
+ super(DRRRDBNet, self).__init__()
66
  self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3,
67
  stride=1, padding=1)
68
+ self.res_block = nn.Sequential(*[RRD_block(channels, growth_channels, residual_beta) for _ in range(6)])
69
+ self.res_block2 = nn.Sequential(*[RRD_block(channels, growth_channels, residual_beta) for _ in range(6)])
70
+ self.res_block3 = nn.Sequential(*[RRD_block(channels, growth_channels, residual_beta) for _ in range(6)])
71
+ self.res_block4 = nn.Sequential(*[RRD_block(channels, growth_channels, residual_beta) for _ in range(5)])
72
+ self.dropout = nn.Dropout(0.1)
73
  self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
74
  stride=1, padding=1)
 
75
  self.upsample = nn.Sequential(
76
  UpsampleBlock(channels, upscale_factor), UpsampleBlock(channels, upscale_factor),
77
  )
78
 
79
  self.conv3 = nn.Sequential(
80
  nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
81
+ nn.LeakyReLU(0.2, True),
82
  )
83
 
84
  self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))
85
 
86
  def forward(self, x):
 
87
  out1 = self.conv1(x)
88
+ t_out1 = self.res_block(out1)
89
+ t_out2 = self.dropout(t_out1)
90
+ t_out3 = self.res_block2(t_out2)
91
+ t_out4 = self.dropout(t_out3)
92
+ t_out5 = self.res_block3(t_out4)
93
+ t_out6 = self.dropout(t_out5)
94
+ out2 = self.conv2(self.res_block4(t_out6))
95
  out3 = torch.add(out2, out1)
96
  out4 = self.upsample(out3)
97
  out5 = self.conv3(out4)
 
205
  out = self.classifier(out)
206
 
207
  return out
208
+
209
+
210
  #############################################
211
  def weights_init(m):
212
  if isinstance(m, nn.Conv2d):
 
214
  m.weight.data *= 0.1
215
  if m.bias is not None:
216
  nn.init.constant_(m.bias, 0)