# Hypernetwork modules: regress convolution kernels for a target network
# from pairs of latent feature codes.
from collections import OrderedDict
import torch
import torch.nn as nn
def shape_to_num_params(shapes):
    """Return the total number of parameters described by *shapes*.

    Each entry of *shapes* is a 1-D tensor of dimension sizes; the result is
    the sum over entries of the product of its dimensions. An empty list
    yields 0.
    """
    return int(sum(torch.prod(dims) for dims in shapes))
class WeightRegressor(nn.Module):
    """Regress a convolution weight kernel from a pair of feature maps.

    The two input code maps are concatenated, fused back to ``input_dim``
    channels, downsampled to a flat code, and pushed through two learned
    linear projections to produce a kernel tensor of shape
    ``(batch, out_channels, in_channels, kernel_size, kernel_size)``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels

        # 1x1 conv merges the two concatenated code maps back to input_dim channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(2 * input_dim, input_dim, kernel_size=1, padding=0, stride=1, bias=True),
            nn.InstanceNorm2d(input_dim),
            nn.ReLU(),
        )

        # One stride-1 conv followed by three stride-2 convs: spatial size is
        # reduced 8x. The matmul against w1 in forward() requires the flattened
        # feature to have exactly hidden_dim elements, i.e. the network assumes
        # an 8x8 spatial input -- TODO confirm against callers.
        extractor_layers = [
            nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(),
        ]
        for in_ch, out_ch in ((64, 64), (64, 64), (64, hidden_dim)):
            extractor_layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=True)
            )
            extractor_layers.append(nn.ReLU())
        self.feature_extractor = nn.Sequential(*extractor_layers)

        # Two-stage linear mapper: code -> per-input-channel codes -> kernels.
        self.w1 = nn.Parameter(torch.randn((hidden_dim, in_channels * hidden_dim)))
        self.b1 = nn.Parameter(torch.randn((in_channels * hidden_dim)))
        self.w2 = nn.Parameter(torch.randn((hidden_dim, out_channels * kernel_size * kernel_size)))
        self.b2 = nn.Parameter(torch.randn((out_channels * kernel_size * kernel_size)))
        self.weight_init()

    def weight_init(self):
        """Kaiming-normal init for the mapper weights, zeros for the biases."""
        for weight in (self.w1, self.w2):
            nn.init.kaiming_normal_(weight)
        for bias in (self.b1, self.b2):
            nn.init.zeros_(bias)

    def forward(self, w_image_codes, w_bar_codes):
        """Return predicted kernels of shape (B, out_ch, in_ch, k, k)."""
        batch = w_image_codes.size(0)
        fused = self.fusion(torch.cat((w_image_codes, w_bar_codes), 1))
        code = self.feature_extractor(fused).view(batch, -1)
        # First projection fans the code out into one hidden code per input channel.
        per_channel = (torch.matmul(code, self.w1) + self.b1).view(
            batch, self.in_channels, self.hidden_dim
        )
        # Second projection turns each channel code into a stack of k x k kernels.
        flat_kernels = torch.matmul(per_channel, self.w2) + self.b2
        return flat_kernels.view(
            batch, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size
        )
class Hypernetwork(nn.Module):
    """Predict per-layer convolution weights of a target network from latent codes.

    ``target_shape`` maps a dotted target layer name (e.g. ``"block.conv.weight"``)
    to a dict with at least:
      - ``"shape"``: the weight tensor shape (4-D conv or 2-D linear),
      - ``"w_idx"``: index into the latent-code stack used for that layer.
    One ``WeightRegressor`` is instantiated per target layer; biases are not
    predicted (only 4-D conv and 2-D linear weight shapes are handled).
    """

    def __init__(self, input_dim=512, hidden_dim=64, target_shape=None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.target_shape = target_shape

        num_predicted_weights = 0
        weight_regressors = OrderedDict()
        # nn.ModuleDict keys may not contain ".", so dots become "_". Keep an
        # explicit reverse mapping: reconstructing the original name via
        # ".".join(key.split("_")) is NOT invertible when a layer name itself
        # contains underscores (e.g. "conv_block.weight" would be mangled into
        # "conv.block.weight" and break the target_shape lookup).
        self._module_key_to_layer = {}
        for layer_name in target_shape:
            module_key = layer_name.replace(".", "_")
            self._module_key_to_layer[module_key] = layer_name
            shape = target_shape[layer_name]["shape"]
            # Bias terms are deliberately not considered here.
            if len(shape) == 4:
                # Conv weight (out, in, kh, kw); assumes square kernels (kh == kw).
                out_channels, in_channels, kernel_size = shape[:3]
            else:
                # Linear weight (out, in): treated as a 1x1 kernel.
                out_channels, in_channels = shape
                kernel_size = 1
            num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))])
            weight_regressors[module_key] = WeightRegressor(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=kernel_size,
                out_channels=out_channels,
                in_channels=in_channels,
            )
        self.weight_regressors = nn.ModuleDict(weight_regressors)
        # Total number of scalar weights this hypernetwork predicts.
        self.num_predicted_weights = num_predicted_weights

    def forward(self, w_image_codes, w_bar_codes):
        """Return a dict mapping each target layer name to its predicted weights.

        Both code tensors are indexed as ``codes[:, w_idx]`` per layer, so they
        are expected to be 5-D: (batch, num_layers, C, H, W) -- TODO confirm.
        """
        bs = w_image_codes.size(0)
        out_weights = {}
        for module_key, regressor in self.weight_regressors.items():
            layer_name = self._module_key_to_layer[module_key]
            w_idx = self.target_shape[layer_name]["w_idx"]
            weights = regressor(
                w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :]
            )
            # Reshape the regressor output to the exact target weight shape.
            out_weights[layer_name] = weights.view(bs, *self.target_shape[layer_name]["shape"])
        return out_weights