import torch import torch.nn as nn import torchvision def get_batchnorm_layer(opts): if opts.norm_layer == "batch": norm_layer = nn.BatchNorm2d elif opts.layer == "spectral_instance": norm_layer = nn.InstanceNorm2d else: print("not implemented") exit() return norm_layer def get_conv2d_layer(in_c, out_c, k, s, p=0, dilation=1, groups=1): return nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=k, stride=s, padding=p,dilation=dilation, groups=groups) def get_deconv2d_layer(in_c, out_c, k=1, s=1, p=1): return nn.Sequential( nn.Upsample(scale_factor=2, mode="bilinear"), nn.Conv2d( in_channels=in_c, out_channels=out_c, kernel_size=k, stride=s, padding=p ) ) class Identity(nn.Module): def __init__(self): super(Identity, self).__init__() def forward(self, x): return x