"""Recurrent residual model that applies one shared ResNet body repeatedly."""

import torch
import torch.nn as nn

from . import common
from .LamResNet import ResNet


def build_model(args):
    """Build the model for this module from the parsed options ``args``."""
    return RecLamResNet(args)


class conv_end(nn.Module):
    """Convolution followed by a pixel-shuffle rearrangement.

    Args:
        in_channels: forwarded to ``common.default_conv``.
        out_channels: forwarded to ``common.default_conv``.
        kernel_size: forwarded to ``common.default_conv``.
        ratio: upscale factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
        super(conv_end, self).__init__()

        # NOTE(review): nn.PixelShuffle requires its input channel count to be
        # divisible by ratio ** 2. If default_conv emits `out_channels`
        # channels, the defaults (3, 2) would not satisfy that -- confirm.
        modules = [
            common.default_conv(in_channels, out_channels, kernel_size),
            nn.PixelShuffle(ratio),
        ]

        self.uppath = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the conv + pixel-shuffle pipeline."""
        return self.uppath(x)


class RecLamResNet(nn.Module):
    """Apply a single shared ``ResNet`` body ``n_scales`` times with a skip.

    Options read from ``args``: ``rgb_range``, ``detach``, ``n_resblocks``,
    ``n_feats``, ``kernel_size`` and ``n_scales``. ``args`` is also passed
    through to ``ResNet``.
    """

    def __init__(self, args):
        super(RecLamResNet, self).__init__()
        self.rgb_range = args.rgb_range
        # Offset subtracted from the input and added back onto every output.
        self.mean = self.rgb_range / 2
        self.is_detach = args.detach

        self.n_resblocks = args.n_resblocks
        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size

        self.n_scales = args.n_scales

        # One body instance is reused by every iteration in forward().
        # presumably the two 3s are input/output channel counts -- confirm
        # against LamResNet.ResNet.
        self.body_model = ResNet(args, 3, 3, mean_shift=False)

    def forward(self, input_lst):
        """Refine ``input_lst[0]`` iteratively and return every intermediate.

        Args:
            input_lst: indexable container; only element 0 is read. It is
                NOT modified (the mean-shifted tensor is kept in a local).

        Returns:
            A list of length ``n_scales`` in reversed order: the result of
            iteration ``i`` is stored at index ``n_scales - i - 1``, so the
            final refinement is at index 0. Each entry has ``self.mean``
            added back.
        """
        # Shift into a local rather than writing back into input_lst[0]:
        # overwriting the caller's list would corrupt their copy of the input
        # and subtract the mean again on every repeated call.
        last_output = input_lst[0] - self.mean

        output_lst = [None] * self.n_scales

        for i in range(self.n_scales):
            if self.is_detach:
                # Cut the autograd graph so gradients do not flow back into
                # earlier iterations.
                last_output = last_output.detach()
            # Residual step: the body's result is added to its own input.
            output = self.body_model(last_output) + last_output
            output_lst[self.n_scales - i - 1] = output + self.mean
            last_output = output

        return output_lst