| import torch | |
| import torch.nn as nn | |
| import torchvision | |
def get_batchnorm_layer(opts):
    """Return the 2-D normalization layer class selected by ``opts``.

    The class itself is returned (not an instance), so the caller decides
    the number of features and any other constructor arguments.

    Args:
        opts: Options object exposing a ``norm_layer`` string attribute.
            Recognized values are ``"batch"`` and ``"spectral_instance"``.

    Returns:
        ``nn.BatchNorm2d`` for ``"batch"``, ``nn.InstanceNorm2d`` for
        ``"spectral_instance"``.

    Raises:
        NotImplementedError: If the requested normalization is not one of
            the recognized values.
    """
    choice = opts.norm_layer
    if choice == "batch":
        return nn.BatchNorm2d

    # Earlier revisions read this option from ``opts.layer`` while the first
    # branch read ``opts.norm_layer``; the legacy attribute is still honored
    # so option objects built for the old spelling keep working.
    legacy_choice = getattr(opts, "layer", None)
    if choice == "spectral_instance" or legacy_choice == "spectral_instance":
        return nn.InstanceNorm2d

    # Raise instead of print + exit(): a bare exit() ends the process with a
    # success status and hides the configuration error from the caller.
    raise NotImplementedError(
        f"normalization layer {choice!r} is not implemented"
    )
def get_conv2d_layer(in_c, out_c, k, s, p=0, dilation=1, groups=1):
    """Build an ``nn.Conv2d`` from short positional arguments.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        k: Kernel size.
        s: Stride.
        p: Padding (defaults to 0).
        dilation: Kernel dilation (defaults to 1).
        groups: Number of channel groups (defaults to 1).

    Returns:
        A freshly constructed ``nn.Conv2d``; all remaining constructor
        arguments keep their PyTorch defaults.
    """
    conv_kwargs = {
        "in_channels": in_c,
        "out_channels": out_c,
        "kernel_size": k,
        "stride": s,
        "padding": p,
        "dilation": dilation,
        "groups": groups,
    }
    return nn.Conv2d(**conv_kwargs)
def get_deconv2d_layer(in_c, out_c, k=1, s=1, p=1):
    """Build an upsample-then-convolve block.

    The block first doubles the spatial size with bilinear interpolation
    and then applies a regular ``nn.Conv2d``.

    Args:
        in_c: Number of input channels of the convolution.
        out_c: Number of output channels of the convolution.
        k: Convolution kernel size (defaults to 1).
        s: Convolution stride (defaults to 1).
        p: Convolution padding (defaults to 1).

    Returns:
        An ``nn.Sequential`` holding the ``nn.Upsample`` followed by the
        ``nn.Conv2d``.

    Note:
        With the default ``k=1, s=1, p=1`` the convolution itself adds two
        pixels to each spatial dimension on top of the 2x upsampling.
    """
    upsample = nn.Upsample(scale_factor=2, mode="bilinear")
    conv = nn.Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=k,
        stride=s,
        padding=p,
    )
    return nn.Sequential(upsample, conv)
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        # No parameters or buffers; only the base nn.Module setup is needed.
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x