Spaces:
Running
Running
| import torch.nn as nn | |
| import models.basicblock as B | |
| import torch | |
| """ | |
| # -------------------------------------------- | |
| # SRMD (15 conv layers) | |
| # -------------------------------------------- | |
| Reference: | |
| @inproceedings{zhang2018learning, | |
| title={Learning a single convolutional super-resolution network for multiple degradations}, | |
| author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, | |
| booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, | |
| pages={3262--3271}, | |
| year={2018} | |
| } | |
| http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Learning_a_Single_CVPR_2018_paper.pdf | |
| """ | |
| # -------------------------------------------- | |
| # SRMD (SRMD, in_nc = 3+15+1 = 19) | |
| # SRMD (SRMDNF, in_nc = 3+15 = 18) | |
| # -------------------------------------------- | |
class SRMD(nn.Module):
    """Plain conv stack followed by an upsampling block.

    The network is ``head -> (nb - 2) body convs -> upsample tail`` and is
    assembled from the helpers in ``models.basicblock`` (imported as ``B``).
    """

    def __init__(self, in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle'):
        """
        # ------------------------------------
        in_nc: channel number of input, default: 3+15+1 = 19
               (use 3+15 = 18 for the noise-free SRMDNF variant)
        out_nc: channel number of output
        nc: channel number of the intermediate feature maps
        nb: layer count; the body holds nb-2 conv blocks, the head and the
            upsampling tail account for the remaining two
        upscale: scale factor, handed to the upsample block as str(upscale)
        act_mode: mode suffix appended to 'C' for B.conv; must contain 'R'
                  or 'L' (examples: R, L, BR, BL, IR, IL). The head only uses
                  the last character of act_mode.
                  NOTE(review): 'BR' presumably means BN+ReLU -- confirm
                  against models.basicblock.
        upsample_mode: 'upconv' | 'pixelshuffle' | 'convtranspose',
                       default 'pixelshuffle'
        # ------------------------------------
        Raises NotImplementedError for an unknown upsample_mode.
        """
        super(SRMD, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True

        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        # Head: conv + the last character of act_mode only.
        m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        # Body: nb-2 identical conv blocks using the full act_mode.
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = upsample_block(nc, out_nc, mode=str(upscale), bias=bias)

        self.model = B.sequential(m_head, *m_body, m_tail)

    def forward(self, x, k_pca=None):
        """Run the network on ``x``.

        x: input tensor; its channel dimension is dim 1.
        k_pca: optional per-image degradation tensor. When given, it is
               stretched over the spatial size of ``x`` (its last two dims
               must be 1 or already equal to those of ``x``) and concatenated
               to ``x`` along dim 1 before the network is applied. When None,
               ``x`` is fed to the network unchanged, i.e. the caller has
               already concatenated the degradation maps.
        Returns the output of ``self.model``.
        """
        if k_pca is not None:
            # expand() broadcasts size-1 spatial dims without copying; the
            # copy happens once inside torch.cat.
            m = k_pca.expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat((x, m), 1)
        x = self.model(x)
        return x
if __name__ == '__main__':
    # Smoke test: build an 18-channel-input model, describe it, and push one
    # random batch through it.
    from utils import utils_model

    model = SRMD(in_nc=18, out_nc=3, nc=64, nb=15, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
    print(utils_model.describe_model(model))

    x = torch.randn((2, 3, 100, 100))
    k_pca = torch.randn(2, 15, 1, 1)
    # Stretch the per-image 15-dim vector over the spatial grid and stack it
    # with the 3-channel image so the input carries the in_nc=18 channels the
    # model was built for.
    m = k_pca.repeat(1, 1, x.size()[-2], x.size()[-1])
    x = torch.cat((x, m), 1)
    x = model(x)
    print(x.shape)
| # run models/network_srmd.py | |