Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| class ImageEncoder(nn.Module): | |
| def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm): | |
| super(ImageEncoder, self).__init__() | |
| n_downsampling = int(math.log(img_size, 2)) | |
| ks_list = [5] * (n_downsampling - n_downsampling // 3) + [3] * (n_downsampling // 3) | |
| stride_list = [2] * n_downsampling | |
| chn_mult = [] | |
| for i in range(n_downsampling): | |
| chn_mult.append(2 ** (i + 1)) | |
| encoder = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True, padding_mode='replicate'), | |
| norm_layer([ngf, 2 ** n_downsampling, 2 ** n_downsampling]), | |
| nn.ReLU(True)] | |
| for i in range(n_downsampling): # add downsampling layers | |
| if i == 0: | |
| chn_prev = ngf | |
| else: | |
| chn_prev = ngf * chn_mult[i - 1] | |
| chn_next = ngf * chn_mult[i] | |
| encoder += [nn.Conv2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, padding_mode='replicate'), | |
| norm_layer([chn_next, 2 ** (n_downsampling - 1 - i), 2 ** (n_downsampling - 1 - i)]), | |
| nn.ReLU(True)] | |
| self.encode = nn.Sequential(*encoder) | |
| self.flatten = nn.Flatten() | |
| def forward(self, input): | |
| """Standard forward""" | |
| ret = self.encode(input) | |
| img_feat = self.flatten(ret) | |
| output = {} | |
| output['img_feat'] = img_feat | |
| return output | |