| | import torch
|
| | import torch.nn as nn
|
| |
|
| | from . import common
|
| | from .LamResNet import ResNet
|
| |
|
| |
|
def build_model(args):
    """Factory hook: instantiate a :class:`RecLamResNet` configured from ``args``."""
    model = RecLamResNet(args)
    return model
|
| |
|
| |
|
class conv_end(nn.Module):
    """Convolution followed by pixel-shuffle upsampling.

    ``common.default_conv`` maps ``in_channels`` to ``out_channels``. The
    result is then rearranged by ``nn.PixelShuffle(ratio)``, which trades
    channels for spatial size: C -> C / ratio**2, H -> H * ratio,
    W -> W * ratio.

    NOTE(review): ``nn.PixelShuffle`` needs the channel count to be divisible
    by ``ratio ** 2``, so the default arguments (out_channels=3, ratio=2)
    would fail in ``forward``. Callers presumably pass a compatible
    ``out_channels`` (e.g. 12 for ratio 2); confirm.
    """

    def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
        super().__init__()
        conv = common.default_conv(in_channels, out_channels, kernel_size)
        shuffle = nn.PixelShuffle(ratio)
        # Submodules are registered in the same order (conv, then shuffle).
        self.uppath = nn.Sequential(conv, shuffle)

    def forward(self, x):
        """Apply the conv + pixel-shuffle path to ``x`` and return the result."""
        path = self.uppath
        return path(x)
|
| |
|
| |
|
class RecLamResNet(nn.Module):
    """Recurrent residual model that applies one shared body network repeatedly.

    A single ``ResNet`` (``body_model``) is applied ``n_scales`` times in a
    chain. Each step adds a global skip connection (``body(x) + x``), and
    every intermediate result is returned.

    Args:
        args: configuration namespace. This class reads ``rgb_range``,
            ``detach``, ``n_resblocks``, ``n_feats``, ``kernel_size`` and
            ``n_scales``. ``args`` is also forwarded to ``ResNet``.
    """

    def __init__(self, args):
        super().__init__()

        self.rgb_range = args.rgb_range
        # Inputs are shifted down by half the value range before the body
        # runs; outputs are shifted back up by the same amount.
        self.mean = self.rgb_range / 2
        # When true, the autograd graph is cut between recurrent steps.
        self.is_detach = args.detach

        # Stored for reference; not used directly by this class.
        self.n_resblocks = args.n_resblocks
        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size

        # Number of recurrent applications of body_model, which is also the
        # length of the list returned by forward().
        self.n_scales = args.n_scales

        # One body shared by all recurrent steps. The (3, 3) arguments are
        # presumably in/out channel counts; confirm against LamResNet.ResNet.
        self.body_model = ResNet(args, 3, 3, mean_shift=False)

    def forward(self, input_lst):
        """Run the body ``n_scales`` times, starting from ``input_lst[0]``.

        Args:
            input_lst: indexable sequence (list or tuple) whose first element
                is the input. Other elements are ignored. The sequence is not
                modified.

        Returns:
            A list of ``n_scales`` outputs with the mean shift undone.
            Index 0 holds the result of the last recurrent step; index
            ``n_scales - 1`` holds the result of the first step.
        """
        # Keep the shifted input local instead of writing it back into
        # input_lst, so the caller's list is never mutated.
        last_output = input_lst[0] - self.mean
        output_lst = [None] * self.n_scales
        for i in range(self.n_scales):
            if self.is_detach:
                # Stop gradients from flowing into earlier steps.
                last_output = last_output.detach()
            output = self.body_model(last_output) + last_output
            # Fill from the back so the final step lands at index 0.
            output_lst[self.n_scales - i - 1] = output + self.mean
            last_output = output
        return output_lst
|
| |
|