from collections import OrderedDict

import torch
import torch.nn as nn


def shape_to_num_params(shapes):
    """Return the total number of elements described by a list of shapes.

    Args:
        shapes: Iterable of shapes. Each shape may be a 1-D tensor of
            dimension sizes (the original calling convention) or any
            sequence of ints accepted by ``torch.as_tensor``.

    Returns:
        int: Sum over all shapes of the product of their dimensions;
        ``0`` for an empty iterable.
    """
    # Accumulate in Python ints so large counts are not truncated by a
    # fixed-width tensor dtype.
    return sum(int(torch.as_tensor(shape).prod().item()) for shape in shapes)


class WeightRegressor(nn.Module):
    """Regress a pair of feature maps to a convolution weight kernel.

    Args:
        input_dim: Channel count of each of the two input feature maps.
        hidden_dim: Channel count produced by the feature extractor and the
            inner dimension of the two linear maps.
        kernel_size: Spatial size ``k`` of the predicted (square) kernel.
        out_channels: Output channels of the predicted kernel.
        in_channels: Input channels of the predicted kernel.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels

        # Feature Transformer: fuse the two channel-concatenated inputs back
        # down to ``input_dim`` channels with a 1x1 convolution.
        self.fusion = nn.Sequential(
            nn.Conv2d(2 * self.input_dim, self.input_dim, kernel_size=1, padding=0, stride=1, bias=True),
            nn.InstanceNorm2d(self.input_dim),
            nn.ReLU(),
        )

        # One resolution-preserving conv followed by three stride-2 convs,
        # i.e. an 8x spatial reduction overall.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, self.hidden_dim, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
        )

        # Linear Mapper. The tensors are created with ``randn`` and then
        # re-initialised in ``weight_init``.
        kernel_area = self.kernel_size * self.kernel_size
        self.w1 = nn.Parameter(torch.randn((self.hidden_dim, self.in_channels * self.hidden_dim)))
        self.b1 = nn.Parameter(torch.randn((self.in_channels * self.hidden_dim)))
        self.w2 = nn.Parameter(torch.randn((self.hidden_dim, self.out_channels * kernel_area)))
        self.b2 = nn.Parameter(torch.randn((self.out_channels * kernel_area)))

        self.weight_init()

    def weight_init(self):
        """Kaiming-normal init for the linear-map weights, zeros for biases."""
        nn.init.kaiming_normal_(self.w1)
        nn.init.zeros_(self.b1)
        nn.init.kaiming_normal_(self.w2)
        nn.init.zeros_(self.b2)

    def forward(self, w_image_codes, w_bar_codes):
        """Predict one kernel per batch element.

        Args:
            w_image_codes: Feature map of shape ``(bs, input_dim, H, W)``.
            w_bar_codes: Feature map with the same shape as
                ``w_image_codes``; the two are concatenated on dim 1.

        Returns:
            Tensor of shape
            ``(bs, out_channels, in_channels, kernel_size, kernel_size)``.
        """
        bs = w_image_codes.size(0)

        # Feature Transformation
        out = self.fusion(torch.cat((w_image_codes, w_bar_codes), 1))
        out = self.feature_extractor(out)
        # The matmul with ``w1`` needs exactly ``hidden_dim`` features per
        # sample, so the extractor output must be 1x1 spatially
        # (presumably 8x8 inputs -- TODO confirm against the caller).
        out = out.view(bs, -1)

        # Linear map to weights
        out = torch.matmul(out, self.w1) + self.b1
        out = out.view(bs, self.in_channels, self.hidden_dim)
        out = torch.matmul(out, self.w2) + self.b2
        # NOTE(review): ``out`` is (bs, in_channels, out_channels * k * k)
        # here; the view below reinterprets that buffer without permuting.
        # Left untouched so behaviour with trained parameters is unchanged.
        kernel = out.view(bs, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)

        return kernel


class Hypernetwork(nn.Module):
    """Hold one ``WeightRegressor`` per target layer and predict their weights.

    Args:
        input_dim: Channel count of the per-layer input feature maps.
        hidden_dim: Hidden size passed to every ``WeightRegressor``.
        target_shape: Mapping from layer name to a dict with keys
            ``"shape"`` (2-D ``(out, in)`` or 4-D ``(out, in, k, k)`` weight
            shape) and ``"w_idx"`` (index used on dim 1 of the inputs to
            ``forward``). Required despite the ``None`` default.

    Raises:
        TypeError: If ``target_shape`` is ``None``.
        ValueError: If two layer names map to the same module key.
    """

    def __init__(self, input_dim=512, hidden_dim=64, target_shape=None):
        super().__init__()
        if target_shape is None:
            raise TypeError(
                "target_shape must be a mapping from layer name to "
                "{'shape': ..., 'w_idx': ...}, got None"
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.target_shape = target_shape

        num_predicted_weights = 0
        weight_regressors = OrderedDict()
        for layer_name in target_shape:
            module_key = self._module_key(layer_name)
            if module_key in weight_regressors:
                # Previously the later layer silently replaced the earlier one.
                raise ValueError(
                    f"layer name {layer_name!r} collides with another layer "
                    f"on module key {module_key!r}"
                )

            shape = target_shape[layer_name]["shape"]  # bias is not considered
            if len(shape) == 4:
                # Only the first spatial extent is used; assumes square kernels.
                out_channels, in_channels, kernel_size = shape[:3]
            else:
                out_channels, in_channels = shape
                kernel_size = 1

            num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))])
            weight_regressors[module_key] = WeightRegressor(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=kernel_size,
                out_channels=out_channels,
                in_channels=in_channels,
            )

        self.weight_regressors = nn.ModuleDict(weight_regressors)
        self.num_predicted_weights = num_predicted_weights

    @staticmethod
    def _module_key(layer_name):
        """Map a layer name to its ModuleDict key ("." is not allowed there)."""
        return layer_name.replace(".", "_")

    def forward(self, w_image_codes, w_bar_codes):
        """Predict weights for every target layer.

        Args:
            w_image_codes: 5-D tensor indexed as ``[:, w_idx, :, :, :]`` to
                select the feature map for each layer.
            w_bar_codes: Tensor indexed the same way as ``w_image_codes``.

        Returns:
            dict: Original layer name -> tensor of shape ``(bs, *shape)``,
            in ``target_shape`` iteration order.
        """
        bs = w_image_codes.size(0)

        out_weights = {}
        # Iterate the original names and derive the module key from them.
        # Reversing the key ("_" -> ".") is lossy for names that already
        # contain "_" and used to end in a KeyError.
        for layer_name in self.target_shape:
            spec = self.target_shape[layer_name]
            w_idx = spec["w_idx"]
            regressor = self.weight_regressors[self._module_key(layer_name)]
            weights = regressor(
                w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :]
            )
            out_weights[layer_name] = weights.view(bs, *list(spec["shape"]))
        return out_weights