File size: 2,588 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Maps torchvision-style ResNet submodule names ("layer{stage}.{block}") to a
# flat sequential naming scheme ("body.{index}").  The block counts per stage
# (3, 4, 6, 3) match a ResNet-50-style backbone.
# NOTE(review): presumably used to translate state-dict keys when loading a
# torchvision checkpoint into a "body"-named module — confirm at the call site.
RESNET_MAPPING = {
    "layer1.0": "body.0",
    "layer1.1": "body.1",
    "layer1.2": "body.2",
    "layer2.0": "body.3",
    "layer2.1": "body.4",
    "layer2.2": "body.5",
    "layer2.3": "body.6",
    "layer3.0": "body.7",
    "layer3.1": "body.8",
    "layer3.2": "body.9",
    "layer3.3": "body.10",
    "layer3.4": "body.11",
    "layer3.5": "body.12",
    "layer4.0": "body.13",
    "layer4.1": "body.14",
    "layer4.2": "body.15",
}


def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def toogle_grad(model, flag=True):
    """Enable or disable gradient tracking for every parameter of *model*.

    NOTE(review): "toogle" is a typo for "toggle", but the name is part of the
    public interface and is kept for compatibility.
    """
    for param in model.parameters():
        param.requires_grad_(flag)


def stylegan_to_classifier(x, out_size=(224, 224)):
    """Convert a StyleGAN output batch to classifier input space.

    Rescales from [-1, 1] to [0, 1] (clamped), resizes to *out_size* with
    bilinear interpolation, then applies per-channel mean/std normalization
    (ImageNet statistics).  The input tensor is not modified.
    """
    img = torch.clamp(x * 0.5 + 0.5, 0, 1)
    img = F.interpolate(img, size=out_size, mode='bilinear')
    means = (0.485, 0.456, 0.406)
    stds = (0.229, 0.224, 0.225)
    for ch in range(3):
        img[:, ch] = (img[:, ch] - means[ch]) / stds[ch]
    return img


def get_stylespace_from_w(w, G):
    """Map W+ latents to style space via each layer's modulation layer.

    Returns a pair (styles, rgb_styles): modulation outputs for the main conv
    layers and the to-RGB layers, in network order.  Layer i of the network
    consumes column i of *w* (w[:, i]).
    """
    # The noise list is only consumed by the zip below, where its slices bound
    # the number of iterations alongside G.convs / G.to_rgbs.
    noise = [getattr(G.noises, f'noise_{k}') for k in range(G.num_layers)]

    styles = [G.conv1.conv.modulation(w[:, 0])]
    rgb_styles = [G.to_rgb1.conv.modulation(w[:, 1])]

    idx = 1
    paired_layers = zip(
        G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs
    )
    for first_conv, second_conv, _noise_a, _noise_b, to_rgb in paired_layers:
        styles.append(first_conv.conv.modulation(w[:, idx]))
        styles.append(second_conv.conv.modulation(w[:, idx + 1]))
        rgb_styles.append(to_rgb.conv.modulation(w[:, idx + 2]))
        idx += 2
    return styles, rgb_styles


def get_stylespace_from_w_hyperinv(w, G):
    """Map W+ latents to style space for a StyleGAN2-ADA-style generator.

    Walks G.synthesis block by block, feeding per-layer latent columns through
    each conv's affine layer.  The first (lowest-resolution) block has no
    conv0, so it is skipped there.  Note that `i` is intentionally NOT
    advanced after the torgb affine: each block's torgb shares its latent
    column with the following block's conv0, matching the original indexing.

    Fix: removed the dead `block_ws` computation — the original built per-block
    latent slices with w.narrow() into `cur_ws` but never used them (the loop
    indexed `w` directly); behavior is unchanged.

    Args:
        w: latent tensor, indexed per layer as w[:, i]; must have at least as
           many columns as conv layers plus one (for the final torgb reuse).
        G: generator whose .synthesis module exposes `block_resolutions` and
           per-resolution blocks `b{res}` with conv0/conv1/torgb `.affine`.

    Returns:
        (style_space, to_rgb_stylespaces): lists of affine outputs for the
        main convs and the to-RGB layers respectively, in network order.
    """
    with torch.no_grad():
        style_space = []
        to_rgb_stylespaces = []
        synthesis = G.synthesis

        i = 0
        for res in synthesis.block_resolutions:
            block = getattr(synthesis, f"b{res}")
            if i != 0:
                # Every block after the first has an extra conv0.
                style_space.append(block.conv0.affine(w[:, i]))
                i += 1
            style_space.append(block.conv1.affine(w[:, i]))
            i += 1
            # `i` deliberately not incremented: torgb reuses the next block's
            # conv0 column.
            to_rgb_stylespaces.append(block.torgb.affine(w[:, i]))

    return style_space, to_rgb_stylespaces