File size: 4,349 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from collections import OrderedDict

import torch
import torch.nn as nn


def shape_to_num_params(shapes):
    """Return the total number of scalar elements described by ``shapes``.

    Each entry of ``shapes`` is one shape, given as a 1-D tensor of
    dimensions, a ``torch.Size``, or a plain tuple/list of ints. The element
    count of a shape is the product of its dimensions (an empty shape counts
    as a single element, like a scalar), and the counts of all shapes are
    summed.

    The arithmetic is carried out on Python numbers, so large totals cannot
    overflow a 32-bit integer.

    Args:
        shapes: Iterable of shapes.

    Returns:
        int: Total number of elements; ``0`` for an empty iterable.
    """
    total = 0
    for shape in shapes:
        count = 1
        # as_tensor + flatten + tolist turns tensors, torch.Size and tuples
        # alike into a flat list of Python numbers (no fixed-width dtype).
        for dim in torch.as_tensor(shape).flatten().tolist():
            count *= dim
        total += count
    return int(total)


class WeightRegressor(nn.Module):
    """Regress a batch of convolution kernels from a pair of feature maps.

    The two inputs are concatenated along dim 1, fused back to ``input_dim``
    channels by a 1x1 convolution, and encoded by a small conv stack into a
    ``hidden_dim``-wide code. A two-stage linear mapper then expands the code
    into ``out_channels * in_channels * kernel_size * kernel_size`` values per
    sample, returned as a 5-D kernel tensor.

    Args:
        input_dim: Channel count of each of the two input feature maps.
        hidden_dim: Channel count of the encoder output and width of the
            linear mapper.
        kernel_size: Side length of the predicted (square) kernel.
        out_channels: Output-channel count of the predicted kernel.
        in_channels: Input-channel count of the predicted kernel.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels

        # Feature transformer: squeeze the concatenated pair back to input_dim channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, padding=0, stride=1, bias=True),
            nn.InstanceNorm2d(input_dim),
            nn.ReLU(),
        )

        # Encoder: a 3x3 conv, then three 4x4 / stride-2 / pad-1 convs (each of
        # which halves an even spatial size), every conv followed by a ReLU.
        encoder_specs = (
            # (in_ch, out_ch, kernel, stride, padding)
            (input_dim, 64, 3, 1, 1),
            (64, 64, 4, 2, 1),
            (64, 64, 4, 2, 1),
            (64, hidden_dim, 4, 2, 1),
        )
        encoder_layers = []
        for c_in, c_out, k, s, p in encoder_specs:
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=True))
            encoder_layers.append(nn.ReLU())
        self.feature_extractor = nn.Sequential(*encoder_layers)

        # Linear mapper parameters, applied as `x @ w + b`. They are drawn with
        # randn here and immediately re-initialised by weight_init().
        expand_width = in_channels * hidden_dim
        kernel_width = out_channels * kernel_size * kernel_size
        self.w1 = nn.Parameter(torch.randn(hidden_dim, expand_width))
        self.b1 = nn.Parameter(torch.randn(expand_width))
        self.w2 = nn.Parameter(torch.randn(hidden_dim, kernel_width))
        self.b2 = nn.Parameter(torch.randn(kernel_width))

        self.weight_init()

    def weight_init(self):
        """Re-initialise the linear mapper: Kaiming-normal weights, zero biases."""
        # NOTE(review): kaiming_normal_ defaults to mode="fan_in", which for a
        # 2-D tensor is read from size(1). Because w1/w2 are used as `x @ w`,
        # size(1) is their *output* width here, so the std comes out smaller
        # than for a conventional (out, in) layout. Confirm this is intended.
        for weight, bias in ((self.w1, self.b1), (self.w2, self.b2)):
            nn.init.kaiming_normal_(weight)
            nn.init.zeros_(bias)

    def forward(self, w_image_codes, w_bar_codes):
        """Predict one kernel per sample.

        Args:
            w_image_codes: 4-D feature map with ``input_dim`` channels (dim 1).
            w_bar_codes: Feature map of matching shape, concatenated with
                ``w_image_codes`` along dim 1.

        The encoder output is flattened and multiplied by ``w1`` (which has
        ``hidden_dim`` rows). That only works when the encoder output is 1x1
        spatially, e.g. for 8x8 inputs.

        Returns:
            Tensor of shape ``(batch, out_channels, in_channels, kernel_size,
            kernel_size)``.
        """
        batch = w_image_codes.size(0)

        fused = self.fusion(torch.cat((w_image_codes, w_bar_codes), 1))
        code = self.feature_extractor(fused).view(batch, -1)

        # Stage 1: hidden_dim -> in_channels rows of hidden_dim each.
        per_input = (code @ self.w1 + self.b1).view(batch, self.in_channels, self.hidden_dim)
        # Stage 2: each row -> out_channels * k * k values.
        flat_kernel = per_input @ self.w2 + self.b2
        # Plain view (no transpose): memory laid out as (in, out*k*k) is
        # reinterpreted as (out, in, k, k).
        return flat_kernel.view(batch, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)


class Hypernetwork(nn.Module):
    """Collection of ``WeightRegressor`` heads, one per target layer.

    Args:
        input_dim: Channel count of each per-layer input code; forwarded to
            every ``WeightRegressor``.
        hidden_dim: Hidden width forwarded to every ``WeightRegressor``.
        target_shape: Mapping from a dotted layer name to a dict with keys
            ``"shape"`` and ``"w_idx"``. ``"shape"`` is the weight shape to
            predict: 4-D ``(out, in, k, k)`` with a square kernel, or 2-D
            ``(out, in)``. ``"w_idx"`` is the index along dim 1 of the forward
            inputs that feeds this layer. ``None`` is treated as an empty
            mapping.

    Raises:
        ValueError: If two layer names map to the same module key after
            replacing ``"."`` with ``"_"``. The replacement is needed because
            ``nn.ModuleDict`` keys may not contain dots.
    """

    def __init__(self, input_dim=512, hidden_dim=64, target_shape=None):
        super().__init__()
        if target_shape is None:
            target_shape = {}
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.target_shape = target_shape

        num_predicted_weights = 0
        weight_regressors = OrderedDict()
        for layer_name in target_shape:
            module_key = self._module_key(layer_name)
            if module_key in weight_regressors:
                # Without this check the later regressor would silently replace the earlier one.
                raise ValueError(
                    f"layer name {layer_name!r} collides with another target on module key {module_key!r}"
                )
            shape = target_shape[layer_name]["shape"]

            # Only the weight is predicted; biases are not considered.
            if len(shape) == 4:
                out_channels, in_channels, kernel_size = shape[:3]
            else:
                out_channels, in_channels = shape
                kernel_size = 1

            num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))])
            weight_regressors[module_key] = WeightRegressor(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=kernel_size,
                out_channels=out_channels,
                in_channels=in_channels,
            )
        self.weight_regressors = nn.ModuleDict(weight_regressors)
        self.num_predicted_weights = num_predicted_weights

    @staticmethod
    def _module_key(layer_name):
        """Map a dotted layer name to its (dot-free) ModuleDict key."""
        return layer_name.replace(".", "_")

    def forward(self, w_image_codes, w_bar_codes):
        """Predict the weights of every target layer.

        Args:
            w_image_codes: 5-D tensor indexed as ``[:, w_idx, :, :, :]`` per layer.
            w_bar_codes: Tensor indexed the same way as ``w_image_codes``.

        Returns:
            dict mapping each dotted layer name to a tensor of shape
            ``(batch, *target_shape[name]["shape"])``.
        """
        bs = w_image_codes.size(0)
        out_weights = {}

        # Iterate the original dotted names rather than rebuilding them from the
        # module keys: turning "_" back into "." is wrong for names that already
        # contain underscores (e.g. a name such as "to_rgb.conv.weight").
        for layer_name in self.target_shape:
            spec = self.target_shape[layer_name]
            w_idx = spec["w_idx"]
            weights = self.weight_regressors[self._module_key(layer_name)](
                w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :]
            )
            out_weights[layer_name] = weights.view(bs, *list(spec["shape"]))

        return out_weights