import re
import torch.nn as nn
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
import torch


# Returns a function that wraps a given layer with optional spectral norm
# followed by a standard normalization layer
def get_norm_layer(opt, norm_type='instance'):
    # helper function to get the number of output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # without a 'spectral' prefix, the whole string names the norm
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
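
# Hedged usage sketch (not part of the original module): get_norm_layer ignores
# `opt`, so None works as a stand-in here. With norm_type='spectralinstance',
# the returned closure applies spectral norm to the layer, strips its bias, and
# appends a parameter-free InstanceNorm2d:
#
#   norm_layer = get_norm_layer(None, norm_type='spectralinstance')
#   block = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
#   y = block(torch.randn(1, 3, 32, 32))  # -> shape (1, 64, 32, 32)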


# Creates a FADE normalization layer based on the given configuration
class FADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('fade')
        parsed = re.search(r'fade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in FADE'
                             % param_free_norm_type)

        pw = ks // 2
        self.mlp_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, feat):
        # Step 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Step 2. produce scale and bias conditioned on feature map
        gamma = self.mlp_gamma(feat)
        beta = self.mlp_beta(feat)

        # Step 3. apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
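
# Hedged usage sketch (illustrative sizes, not from the original code):
# a FADE layer normalizing 64-channel activations, modulated by a 3-channel
# feature map at the same spatial resolution:
#
#   fade = FADE('fadeinstance3x3', norm_nc=64, label_nc=3)
#   x = torch.randn(1, 64, 32, 32)
#   feat = torch.randn(1, 3, 32, 32)
#   out = fade(x, feat)  # same shape as x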


# Adaptive instance normalization conditioned on a latent code and a blending
# weight alpha, which are mapped jointly to per-channel scale and bias
class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, opt):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)
        self.gamma = nn.Conv2d(in_channel, in_channel, 1)
        self.beta = nn.Conv2d(in_channel, in_channel, 1)
        self.conv_l = nn.Conv2d(2 * opt.latent_dim, in_channel, 1)
        self.relu = nn.ReLU()

    def forward(self, input, latent_code, alpha):
        # broadcast alpha to the latent code's shape and concatenate them
        size = latent_code.size()
        alpha = alpha.expand(size)
        latent_code = torch.cat([latent_code, alpha], 1)

        # map the joint code to per-channel modulation via 1x1 convolutions
        latent_code = latent_code.unsqueeze(2).unsqueeze(3)
        latent_code = self.relu(self.conv_l(latent_code))
        gamma = self.gamma(latent_code)
        beta = self.beta(latent_code)

        # normalize the input, then apply the predicted scale and bias
        out = self.norm(input)
        out = out * (1 + gamma) + beta
        return out
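
# Hedged usage sketch (SimpleNamespace stands in for the real options object;
# latent_dim=8 is an assumed value):
#
#   from types import SimpleNamespace
#   adain = AdaptiveInstanceNorm(64, SimpleNamespace(latent_dim=8))
#   x = torch.randn(1, 64, 32, 32)
#   z = torch.randn(1, 8)
#   alpha = torch.ones(1, 1)  # expanded to z's shape inside forward
#   out = adain(x, z, alpha)  # same shape as x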

# class AdaptiveInstanceNorm(nn.Module):
#     def __init__(self, in_channel, opt):
#         super().__init__()
#         self.norm = nn.InstanceNorm2d(in_channel)
#         self.gamma = nn.Conv2d(in_channel, in_channel, 1)
#         self.beta = nn.Conv2d(in_channel, in_channel, 1)
#         self.conv1x1 = nn.Conv2d(in_channel, in_channel, 1)
#         self.relu = nn.ReLU()
#     def forward(self, input, style, alpha):
#         size = input.size()
#         alpha = alpha.expand(size)
#         gamma = self.gamma(alpha)
#         beta = self.beta(alpha)
#         normalized = self.norm(input)
#         out = normalized * gamma +  beta
#         out = self.relu(self.conv1x1(out))
#         return out