# File size: 1,078 Bytes
import torch
import torch.nn as nn
import torchvision
def get_batchnorm_layer(opts):
    """Return the 2-D normalization layer class selected by ``opts``.

    Args:
        opts: Options object. ``opts.norm_layer`` names the layer to use:
            ``"batch"`` or ``"spectral_instance"``. For backward
            compatibility, ``opts.layer == "spectral_instance"`` is also
            honoured when that attribute exists.

    Returns:
        The layer *class* (not an instance): ``nn.BatchNorm2d`` for
        ``"batch"``, ``nn.InstanceNorm2d`` for ``"spectral_instance"``.

    Raises:
        SystemExit: after printing ``"not implemented"`` when no supported
            layer is selected.
    """
    kind = opts.norm_layer
    if kind == "batch":
        return nn.BatchNorm2d

    # The second branch used to read ``opts.layer`` while the first read
    # ``opts.norm_layer``; an options object carrying only ``norm_layer``
    # therefore raised AttributeError instead of selecting InstanceNorm2d.
    # Check ``norm_layer`` first and keep ``layer`` as a legacy fallback.
    legacy_kind = getattr(opts, "layer", None)
    if kind == "spectral_instance" or legacy_kind == "spectral_instance":
        return nn.InstanceNorm2d

    print("not implemented")
    # Same effect as the ``exit()`` builtin, without depending on the
    # ``site`` module having injected that name.
    raise SystemExit
def get_conv2d_layer(in_c, out_c, k, s, p=0, dilation=1, groups=1):
    """Build an ``nn.Conv2d`` from short-hand argument names.

    Args:
        in_c: Forwarded as ``in_channels``.
        out_c: Forwarded as ``out_channels``.
        k: Forwarded as ``kernel_size``.
        s: Forwarded as ``stride``.
        p: Forwarded as ``padding`` (default 0).
        dilation: Forwarded as ``dilation`` (default 1).
        groups: Forwarded as ``groups`` (default 1).

    Returns:
        A freshly constructed ``nn.Conv2d``; every other constructor
        option is left at the ``nn.Conv2d`` default.
    """
    conv_kwargs = {
        "in_channels": in_c,
        "out_channels": out_c,
        "kernel_size": k,
        "stride": s,
        "padding": p,
        "dilation": dilation,
        "groups": groups,
    }
    return nn.Conv2d(**conv_kwargs)
def get_deconv2d_layer(in_c, out_c, k=1, s=1, p=1):
    """Build an upsample-then-convolve block.

    The returned ``nn.Sequential`` first applies a 2x bilinear
    ``nn.Upsample`` and then an ``nn.Conv2d``.

    Args:
        in_c: ``in_channels`` of the convolution.
        out_c: ``out_channels`` of the convolution.
        k: ``kernel_size`` of the convolution (default 1).
        s: ``stride`` of the convolution (default 1).
        p: ``padding`` of the convolution (default 1).

    Returns:
        ``nn.Sequential(Upsample, Conv2d)``.

    Note:
        With the defaults (``k=1``, ``p=1``) the convolution pads each
        border by one, so its output is 2 larger per spatial dimension
        than the upsampled input.
    """
    upsample = nn.Upsample(scale_factor=2, mode="bilinear")
    conv = nn.Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=k,
        stride=s,
        padding=p,
    )
    return nn.Sequential(upsample, conv)
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        """Initialise the base ``nn.Module``; the layer has no state of its own."""
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x
|