from . import common

import torch
import torch.nn as nn
from einops import rearrange, repeat
from models import register
from .transformer import TransformerEncoder, TransformerDecoder
from argparse import Namespace

MIN_NUM_PATCHES = 12
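
# This module implements TransENet, a super-resolution network that combines a
# convolutional head/tail with multi-stage feature extraction and a transformer
# encoder-decoder fusion stage: three low-resolution feature stages are encoded
# separately, and an upsampled high-resolution stream attends to each of them
# in turn before the final reconstruction (see TransENet.forward below).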


def make_model(args, parent=False):
    return TransENet(args)


class BasicModule(nn.Module):

    def __init__(self, conv, n_feat, kernel_size, block_type='basic', bias=True,
                 bn=False, act=nn.ReLU(True)):
        super(BasicModule, self).__init__()

        self.block_type = block_type

        if block_type == 'basic':
            n_blocks = 10
            m_body = [
                common.BasicBlock(conv, n_feat, n_feat, kernel_size, bias=bias, bn=bn)
                # common.ResBlock(conv, n_feat, kernel_size)
                for _ in range(n_blocks)
            ]
        elif block_type == 'residual':
            n_blocks = 5
            m_body = [
                common.ResBlock(conv, n_feat, kernel_size)
                for _ in range(n_blocks)
            ]
        else:
            raise ValueError('Unsupported block_type: {}'.format(block_type))

        self.body = nn.Sequential(*m_body)

    def forward(self, x):
        res = self.body(x)
        if self.block_type == 'basic':
            out = res + x
        elif self.block_type == 'residual':
            out = res
        return out


# Convenience constructor for TransENet (renamed from `TransENet` so that it is
# not shadowed by the class of the same name defined below).
def make_transenet(scale_ratio, n_feats=64, rgb_range=1):
    args = Namespace()
    args.n_feats = n_feats
    args.scale = [scale_ratio]
    args.patch_size = 48 * args.scale[0]
    args.rgb_range = rgb_range
    args.n_colors = 3
    args.en_depth = 6
    args.de_depth = 1
    return TransENet(args)


class TransENet(nn.Module):

    def __init__(self, args, conv=common.default_conv):
        super(TransENet, self).__init__()

        self.args = args
        self.scale = args.scale[0]
        n_feats = args.n_feats
        kernel_size = 3
        act = nn.ReLU(True)

        # rgb_mean = (0.4916, 0.4991, 0.4565)  # UCMerced data
        # rgb_std = (1.0, 1.0, 1.0)
        # self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head body
        m_head = [
            conv(args.n_colors, n_feats, kernel_size),
        ]
        self.head = nn.Sequential(*m_head)

        # define main body
        self.feat_extrat_stage1 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
        self.feat_extrat_stage2 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)
        self.feat_extrat_stage3 = BasicModule(conv, n_feats, kernel_size, block_type='residual', act=act)

        reduction = 4
        self.stage1_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.stage2_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.stage3_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.up_conv1x1 = conv(n_feats, n_feats // reduction, 1)
        self.span_conv1x1 = conv(n_feats // reduction, n_feats, 1)

        self.upsampler = common.Upsampler(conv, self.scale, n_feats, act=False)

        # define tail body
        self.tail = conv(n_feats, args.n_colors, kernel_size)
        # self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        # define transformer
        image_size = args.patch_size // self.scale
        patch_size = 4
        dim = 512
        en_depth = args.en_depth
        de_depth = args.de_depth
        heads = 6
        mlp_dim = 512
        channels = n_feats // reduction
        dim_head = 32
        dropout = 0.0

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, \
            f'your number of patches ({num_patches}) is too small for attention to be ' \
            f'effective (more than {MIN_NUM_PATCHES} are required); try decreasing your patch size'
        self.patch_size = patch_size

        self.patch_to_embedding_low1 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_low2 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_low3 = nn.Linear(patch_dim, dim)
        self.patch_to_embedding_high = nn.Linear(patch_dim, dim)
        self.embedding_to_patch = nn.Linear(dim, patch_dim)

        self.encoder_stage1 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_stage2 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_stage3 = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)
        self.encoder_up = TransformerEncoder(dim, en_depth, heads, dim_head, mlp_dim, dropout)

        self.decoder1 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
        self.decoder2 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)
        self.decoder3 = TransformerDecoder(dim, de_depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, x, out_size=None):
        # x = self.sub_mean(x)
        x = self.head(x)

        # feature extraction part
        feat_stage1 = self.feat_extrat_stage1(x)
        feat_stage2 = self.feat_extrat_stage2(x)
        feat_stage3 = self.feat_extrat_stage3(x)
        feat_ups = self.upsampler(feat_stage3)

        feat_stage1 = self.stage1_conv1x1(feat_stage1)
        feat_stage2 = self.stage2_conv1x1(feat_stage2)
        feat_stage3 = self.stage3_conv1x1(feat_stage3)
        feat_ups = self.up_conv1x1(feat_ups)

        # transformer part: split each feature map into non-overlapping patch tokens
        p = self.patch_size
        feat_stage1 = rearrange(feat_stage1, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_stage2 = rearrange(feat_stage2, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_stage3 = rearrange(feat_stage3, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        feat_ups = rearrange(feat_ups, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        feat_stage1 = self.patch_to_embedding_low1(feat_stage1)
        feat_stage2 = self.patch_to_embedding_low2(feat_stage2)
        feat_stage3 = self.patch_to_embedding_low3(feat_stage3)
        feat_ups = self.patch_to_embedding_high(feat_ups)

        # encoders
        feat_stage1 = self.encoder_stage1(feat_stage1)
        feat_stage2 = self.encoder_stage2(feat_stage2)
        feat_stage3 = self.encoder_stage3(feat_stage3)
        feat_ups = self.encoder_up(feat_ups)

        # decoders: the upsampled stream attends to each encoded stage in turn
        feat_ups = self.decoder3(feat_ups, feat_stage3)
        feat_ups = self.decoder2(feat_ups, feat_stage2)
        feat_ups = self.decoder1(feat_ups, feat_stage1)

        # fold the tokens back into a high-resolution feature map
        feat_ups = self.embedding_to_patch(feat_ups)
        feat_ups = rearrange(feat_ups, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=self.args.patch_size // p, p1=p, p2=p)
        feat_ups = self.span_conv1x1(feat_ups)

        x = self.tail(feat_ups)
        # x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replacing the pre-trained tail with a new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


if __name__ == "__main__":
    from option import args

    model = TransENet(args)
    model.eval()
    input = torch.rand(1, 3, 48, 48)
    sr = model(input)
    print(sr.size())
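
    # Additional sanity check (a minimal sketch): build the model through the
    # convenience constructor above instead of the external `option` args. The
    # x4 scale and the 48x48 LR input are illustrative values, not settings
    # taken from a released configuration; they satisfy the patch-divisibility
    # asserts in __init__ (patch_size = 48 * 4, image_size = 48).
    model_x4 = make_transenet(scale_ratio=4)
    model_x4.eval()
    with torch.no_grad():
        out = model_x4(torch.rand(1, 3, 48, 48))
    print(out.size())  # expected: torch.Size([1, 3, 192, 192])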