Spaces:
Runtime error
Runtime error
| from . import common | |
| import torch.nn as nn | |
def make_model(args, parent=False):
    """Build a ``BASIC`` network configured by ``args``.

    ``parent`` is accepted but is not read by this function.
    """
    model = BASIC(args)
    return model
class BASIC(nn.Module):
    """Three-stage network: a head, a stack of residual blocks, and a tail.

    The head is a single ``conv`` layer, the body is ``args.n_resblocks``
    ``common.ResBlock`` modules followed by one more ``conv`` layer, and the
    tail is a ``common.Upsampler`` followed by a final ``conv`` layer. The
    forward pass adds the head output to the body output (a skip connection
    around the whole body) before running the tail.
    """

    def __init__(self, args, conv=common.default_conv):
        """Assemble the head, body and tail sub-networks.

        Args:
            args: Options object. The attributes read here are
                ``n_resblocks``, ``n_feats``, ``scale`` (only ``scale[0]``
                is used), ``n_colors`` and ``res_scale``.
            conv: Layer factory invoked with three positional arguments,
                e.g. ``conv(args.n_colors, n_feats, kernel_size)`` --
                presumably (in_channels, out_channels, kernel_size);
                confirm against ``common.default_conv``.
        """
        super(BASIC, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        # One activation module instance is handed to every ResBlock.
        act = nn.ReLU(True)

        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats),
            conv(n_feats, args.n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        """Run ``x`` through head, body (with skip connection) and tail.

        Args:
            x: Input accepted by ``self.head``.

        Returns:
            The output of ``self.tail`` applied to ``body(head(x)) + head(x)``.
        """
        x = self.head(x)

        res = self.body(x)
        res += x

        # The tail must consume the skip-connected features. Passing `x`
        # here would discard everything the body computed.
        x = self.tail(res)
        return x