|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import math |
|
|
import torch.nn.functional as F |
|
|
from torch.nn.utils import weight_norm |
|
|
|
|
|
|
|
|
def init_layer(L): |
|
|
|
|
|
if isinstance(L, nn.Conv2d): |
|
|
n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels |
|
|
L.weight.data.normal_(0,math.sqrt(2.0/float(n))) |
|
|
elif isinstance(L, nn.BatchNorm2d): |
|
|
L.weight.data.fill_(1) |
|
|
L.bias.data.fill_(0) |
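

# Usage sketch (not part of the original module; layer sizes are illustrative):
# init_layer gives Conv2d layers He-style fan-out initialization and sets
# BatchNorm2d to the identity affine transform.
def _demo_init_layer():
    conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    bn = nn.BatchNorm2d(64)
    init_layer(conv)   # weight std = sqrt(2 / (k*k*out_channels)), fan-out He init
    init_layer(bn)     # weight -> 1, bias -> 0
    return conv, bn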
|
|
|
|
|
class distLinear(nn.Module): |
|
|
def __init__(self, indim, outdim): |
|
|
super(distLinear, self).__init__() |
|
|
self.L = weight_norm(nn.Linear(indim, outdim, bias=False), name='weight', dim=0) |
|
|
self.relu = nn.ReLU() |
|
|
|
|
|
def forward(self, x): |
|
|
x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) |
|
|
x_normalized = x.div(x_norm + 0.00001) |
|
|
L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) |
|
|
self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) |
|
|
cos_dist = self.L(x_normalized) |
|
|
scores = 10 * cos_dist |
|
|
return scores |
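

# Usage sketch (an assumption, not in the original file; sizes are illustrative):
# distLinear is a cosine-similarity classifier. Both the input features and the
# class weight vectors are L2-normalized, and the cosine scores are scaled by 10.
def _demo_dist_linear():
    clf = distLinear(indim=512, outdim=5)
    feats = torch.randn(4, 512)
    scores = clf(feats)        # shape (4, 5), values roughly in [-10, 10]
    return scores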
|
|
|
|
|
|
|
|
class Flatten(nn.Module): |
|
|
def __init__(self): |
|
|
super(Flatten, self).__init__() |
|
|
|
|
|
def forward(self, x): |
|
|
return x.view(x.size(0), -1) |
|
|
|
|
|
|
|
|
class LSTMCell(nn.Module): |
|
|
maml = False |
|
|
def __init__(self, input_size, hidden_size, bias=True): |
|
|
super(LSTMCell, self).__init__() |
|
|
self.input_size = input_size |
|
|
self.hidden_size = hidden_size |
|
|
self.bias = bias |
|
|
if self.maml: |
|
|
self.x2h = Linear_fw(input_size, 4 * hidden_size, bias=bias) |
|
|
self.h2h = Linear_fw(hidden_size, 4 * hidden_size, bias=bias) |
|
|
else: |
|
|
self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) |
|
|
self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_parameters(self): |
|
|
std = 1.0 / math.sqrt(self.hidden_size) |
|
|
for w in self.parameters(): |
|
|
w.data.uniform_(-std, std) |
|
|
|
|
|
def forward(self, x, hidden=None): |
|
|
if hidden is None: |
|
|
            # Hidden and cell states must have shape (batch, hidden_size), not the shape of x.
            hx = x.new_zeros(x.size(0), self.hidden_size)
            cx = x.new_zeros(x.size(0), self.hidden_size)
|
|
else: |
|
|
hx, cx = hidden |
|
|
|
|
|
gates = self.x2h(x) + self.h2h(hx) |
|
|
ingate, forgetgate, cellgate, outgate = torch.split(gates, self.hidden_size, dim=1) |
|
|
|
|
|
ingate = torch.sigmoid(ingate) |
|
|
forgetgate = torch.sigmoid(forgetgate) |
|
|
cellgate = torch.tanh(cellgate) |
|
|
outgate = torch.sigmoid(outgate) |
|
|
|
|
|
cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) |
|
|
hy = torch.mul(outgate, torch.tanh(cy)) |
|
|
return (hy, cy) |
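

# Usage sketch (illustrative sizes, not in the original file): one LSTMCell step
# maps an input of shape (batch, input_size) and an (h, c) pair of shape
# (batch, hidden_size) to the next (h, c).
def _demo_lstm_cell():
    cell = LSTMCell(input_size=32, hidden_size=64)
    x = torch.randn(8, 32)
    h = torch.zeros(8, 64)
    c = torch.zeros(8, 64)
    h, c = cell(x, (h, c))     # h, c: (8, 64)
    return h, c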
|
|
|
|
|
|
|
|
class LSTM(nn.Module): |
|
|
def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, bidirectional=False): |
|
|
super(LSTM, self).__init__() |
|
|
|
|
|
self.input_size = input_size |
|
|
self.hidden_size = hidden_size |
|
|
self.num_layers = num_layers |
|
|
self.bias = bias |
|
|
self.batch_first = batch_first |
|
|
self.num_directions = 2 if bidirectional else 1 |
|
|
assert(self.num_layers == 1) |
|
|
|
|
|
self.lstm = LSTMCell(input_size, hidden_size, self.bias) |
|
|
|
|
|
def forward(self, x, hidden=None): |
|
|
|
|
|
if self.batch_first: |
|
|
            x = x.permute(1, 0, 2)
|
|
|
|
|
|
|
|
if hidden is None: |
|
|
h0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device) |
|
|
c0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device) |
|
|
else: |
|
|
h0, c0 = hidden |
|
|
|
|
|
|
|
|
outs = [] |
|
|
hn = h0[0] |
|
|
cn = c0[0] |
|
|
for seq in range(x.size(0)): |
|
|
hn, cn = self.lstm(x[seq], (hn, cn)) |
|
|
outs.append(hn.unsqueeze(0)) |
|
|
outs = torch.cat(outs, dim=0) |
|
|
|
|
|
|
|
|
if self.num_directions == 2: |
|
|
outs_reverse = [] |
|
|
hn = h0[1] |
|
|
cn = c0[1] |
|
|
for seq in range(x.size(0)): |
|
|
                # Index the sequence dimension (dim 0 after any batch_first permute) in reverse.
                seq = x.size(0) - 1 - seq
|
|
hn, cn = self.lstm(x[seq], (hn, cn)) |
|
|
outs_reverse.append(hn.unsqueeze(0)) |
|
|
outs_reverse = torch.cat(outs_reverse, dim=0) |
|
|
outs = torch.cat([outs, outs_reverse], dim=2) |
|
|
|
|
|
|
|
|
if self.batch_first: |
|
|
outs = outs.permute(1, 0, 2) |
|
|
return outs |
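

# Usage sketch (illustrative shapes, not in the original file): this single-layer
# wrapper mirrors nn.LSTM's batch_first / bidirectional flags but returns only
# the output sequence, not the final hidden state.
def _demo_lstm():
    rnn = LSTM(input_size=32, hidden_size=64, batch_first=True, bidirectional=True)
    x = torch.randn(8, 10, 32)     # (batch, seq_len, input_size)
    out = rnn(x)                   # (8, 10, 128): forward and backward halves concatenated
    return out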
|
|
|
|
|
|
|
|
class Linear_fw(nn.Linear): |
|
|
def __init__(self, in_features, out_features, bias=True): |
|
|
super(Linear_fw, self).__init__(in_features, out_features, bias=bias) |
|
|
self.weight.fast = None |
|
|
self.bias.fast = None |
|
|
|
|
|
def forward(self, x): |
|
|
if self.weight.fast is not None and self.bias.fast is not None: |
|
|
out = F.linear(x, self.weight.fast, self.bias.fast) |
|
|
else: |
|
|
out = super(Linear_fw, self).forward(x) |
|
|
return out |
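

# Sketch of how the .fast attributes are typically used in a MAML-style inner
# loop (an assumption about usage, not code from this file): adapted weights are
# stored in weight.fast / bias.fast without touching the originals, and
# forward() picks them up automatically.
def _demo_fast_weights():
    layer = Linear_fw(16, 4)
    x = torch.randn(2, 16)
    loss = layer(x).sum()
    grads = torch.autograd.grad(loss, [layer.weight, layer.bias], create_graph=True)
    inner_lr = 0.01                # illustrative inner-loop learning rate
    layer.weight.fast = layer.weight - inner_lr * grads[0]
    layer.bias.fast = layer.bias - inner_lr * grads[1]
    return layer(x)                # this call now uses the fast weights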
|
|
|
|
|
|
|
|
class Conv2d_fw(nn.Conv2d): |
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True): |
|
|
super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) |
|
|
self.weight.fast = None |
|
|
        if self.bias is not None:
|
|
self.bias.fast = None |
|
|
|
|
|
def forward(self, x): |
|
|
if self.bias is None: |
|
|
if self.weight.fast is not None: |
|
|
out = F.conv2d(x, self.weight.fast, None, stride= self.stride, padding=self.padding) |
|
|
else: |
|
|
out = super(Conv2d_fw, self).forward(x) |
|
|
else: |
|
|
if self.weight.fast is not None and self.bias.fast is not None: |
|
|
out = F.conv2d(x, self.weight.fast, self.bias.fast, stride= self.stride, padding=self.padding) |
|
|
else: |
|
|
out = super(Conv2d_fw, self).forward(x) |
|
|
return out |
|
|
|
|
|
|
|
|
def softplus(x): |
|
|
return torch.nn.functional.softplus(x, beta=100) |
|
|
|
|
|
|
|
|
class FeatureWiseTransformation2d_fw(nn.BatchNorm2d): |
|
|
feature_augment = False |
|
|
def __init__(self, num_features, momentum=0.1, track_running_stats=True): |
|
|
super(FeatureWiseTransformation2d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) |
|
|
self.weight.fast = None |
|
|
self.bias.fast = None |
|
|
if self.track_running_stats: |
|
|
self.register_buffer('running_mean', torch.zeros(num_features)) |
|
|
self.register_buffer('running_var', torch.zeros(num_features)) |
|
|
if self.feature_augment: |
|
|
self.gamma = torch.nn.Parameter(torch.ones(1, num_features, 1, 1)*0.3) |
|
|
self.beta = torch.nn.Parameter(torch.ones(1, num_features, 1, 1)*0.5) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_running_stats(self): |
|
|
if self.track_running_stats: |
|
|
self.running_mean.zero_() |
|
|
self.running_var.fill_(1) |
|
|
|
|
|
def forward(self, x, step=0): |
|
|
if self.weight.fast is not None and self.bias.fast is not None: |
|
|
weight = self.weight.fast |
|
|
bias = self.bias.fast |
|
|
else: |
|
|
weight = self.weight |
|
|
bias = self.bias |
|
|
if self.track_running_stats: |
|
|
out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) |
|
|
else: |
|
|
            # Without running stats, F.batch_norm still expects per-channel (shape [C]) mean/var placeholders.
            out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, momentum=1)
|
|
|
|
|
|
|
|
if self.feature_augment and self.training: |
|
|
gamma = (1 + torch.randn(1, self.num_features, 1, 1, dtype=self.gamma.dtype, device=self.gamma.device)*softplus(self.gamma)).expand_as(out) |
|
|
beta = (torch.randn(1, self.num_features, 1, 1, dtype=self.beta.dtype, device=self.beta.device)*softplus(self.beta)).expand_as(out) |
|
|
out = gamma*out + beta |
|
|
return out |
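

# Usage sketch (illustrative sizes): with the class-level feature_augment switch
# enabled, this layer acts as batch norm but, in training mode, perturbs each
# channel with sampled affine noise whose scale is learned via gamma / beta.
def _demo_feature_wise_transformation():
    FeatureWiseTransformation2d_fw.feature_augment = True
    layer = FeatureWiseTransformation2d_fw(16)
    layer.train()
    x = torch.randn(4, 16, 8, 8)
    out = layer(x)                 # same shape as x, with per-channel noise applied
    FeatureWiseTransformation2d_fw.feature_augment = False
    return out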
|
|
|
|
|
|
|
|
class BatchNorm2d_fw(nn.BatchNorm2d): |
|
|
def __init__(self, num_features, momentum=0.1, track_running_stats=True): |
|
|
super(BatchNorm2d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) |
|
|
self.weight.fast = None |
|
|
self.bias.fast = None |
|
|
if self.track_running_stats: |
|
|
self.register_buffer('running_mean', torch.zeros(num_features)) |
|
|
self.register_buffer('running_var', torch.zeros(num_features)) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_running_stats(self): |
|
|
if self.track_running_stats: |
|
|
self.running_mean.zero_() |
|
|
self.running_var.fill_(1) |
|
|
|
|
|
def forward(self, x, step=0): |
|
|
if self.weight.fast is not None and self.bias.fast is not None: |
|
|
weight = self.weight.fast |
|
|
bias = self.bias.fast |
|
|
else: |
|
|
weight = self.weight |
|
|
bias = self.bias |
|
|
if self.track_running_stats: |
|
|
out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) |
|
|
else: |
|
|
out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, momentum=1) |
|
|
return out |
|
|
|
|
|
|
|
|
class BatchNorm1d_fw(nn.BatchNorm1d): |
|
|
def __init__(self, num_features, momentum=0.1, track_running_stats=True): |
|
|
super(BatchNorm1d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) |
|
|
self.weight.fast = None |
|
|
self.bias.fast = None |
|
|
if self.track_running_stats: |
|
|
self.register_buffer('running_mean', torch.zeros(num_features)) |
|
|
self.register_buffer('running_var', torch.zeros(num_features)) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_running_stats(self): |
|
|
if self.track_running_stats: |
|
|
self.running_mean.zero_() |
|
|
self.running_var.fill_(1) |
|
|
|
|
|
def forward(self, x, step=0): |
|
|
if self.weight.fast is not None and self.bias.fast is not None: |
|
|
weight = self.weight.fast |
|
|
bias = self.bias.fast |
|
|
else: |
|
|
weight = self.weight |
|
|
bias = self.bias |
|
|
if self.track_running_stats: |
|
|
out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) |
|
|
else: |
|
|
out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, momentum=1) |
|
|
return out |
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module): |
|
|
maml = False |
|
|
def __init__(self, indim, outdim, pool = True, padding = 1): |
|
|
super(ConvBlock, self).__init__() |
|
|
self.indim = indim |
|
|
self.outdim = outdim |
|
|
if self.maml: |
|
|
self.C = Conv2d_fw(indim, outdim, 3, padding = padding) |
|
|
self.BN = FeatureWiseTransformation2d_fw(outdim) |
|
|
else: |
|
|
self.C = nn.Conv2d(indim, outdim, 3, padding= padding) |
|
|
self.BN = nn.BatchNorm2d(outdim) |
|
|
self.relu = nn.ReLU(inplace=True) |
|
|
|
|
|
self.parametrized_layers = [self.C, self.BN, self.relu] |
|
|
if pool: |
|
|
self.pool = nn.MaxPool2d(2) |
|
|
self.parametrized_layers.append(self.pool) |
|
|
|
|
|
for layer in self.parametrized_layers: |
|
|
init_layer(layer) |
|
|
self.trunk = nn.Sequential(*self.parametrized_layers) |
|
|
|
|
|
def forward(self,x): |
|
|
out = self.trunk(x) |
|
|
return out |
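

# Usage sketch (illustrative sizes): a ConvBlock is conv3x3 -> BN -> ReLU, with
# an optional 2x2 max-pool that halves the spatial resolution.
def _demo_conv_block():
    block = ConvBlock(3, 64, pool=True)
    x = torch.randn(2, 3, 84, 84)  # 84x84 is a common few-shot input size
    out = block(x)                 # (2, 64, 42, 42)
    return out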
|
|
|
|
|
|
|
|
class SimpleBlock(nn.Module): |
|
|
maml = False |
|
|
def __init__(self, indim, outdim, half_res, leaky=False): |
|
|
super(SimpleBlock, self).__init__() |
|
|
self.indim = indim |
|
|
self.outdim = outdim |
|
|
if self.maml: |
|
|
self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) |
|
|
self.BN1 = BatchNorm2d_fw(outdim) |
|
|
self.C2 = Conv2d_fw(outdim, outdim,kernel_size=3, padding=1,bias=False) |
|
|
self.BN2 = FeatureWiseTransformation2d_fw(outdim) |
|
|
else: |
|
|
self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) |
|
|
self.BN1 = nn.BatchNorm2d(outdim) |
|
|
self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1,bias=False) |
|
|
self.BN2 = nn.BatchNorm2d(outdim) |
|
|
self.relu1 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True) |
|
|
self.relu2 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True) |
|
|
|
|
|
self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] |
|
|
|
|
|
self.half_res = half_res |
|
|
|
|
|
|
|
|
if indim!=outdim: |
|
|
if self.maml: |
|
|
self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False) |
|
|
self.BNshortcut = FeatureWiseTransformation2d_fw(outdim) |
|
|
else: |
|
|
self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) |
|
|
self.BNshortcut = nn.BatchNorm2d(outdim) |
|
|
|
|
|
self.parametrized_layers.append(self.shortcut) |
|
|
self.parametrized_layers.append(self.BNshortcut) |
|
|
self.shortcut_type = '1x1' |
|
|
else: |
|
|
self.shortcut_type = 'identity' |
|
|
|
|
|
for layer in self.parametrized_layers: |
|
|
init_layer(layer) |
|
|
|
|
|
def forward(self, x): |
|
|
out = self.C1(x) |
|
|
out = self.BN1(out) |
|
|
out = self.relu1(out) |
|
|
out = self.C2(out) |
|
|
out = self.BN2(out) |
|
|
short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) |
|
|
out = out + short_out |
|
|
out = self.relu2(out) |
|
|
return out |
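

# Usage sketch (illustrative sizes): with half_res=True the first conv uses
# stride 2, so the block halves the spatial size while the 1x1 shortcut conv
# matches channels and stride.
def _demo_simple_block():
    block = SimpleBlock(64, 128, half_res=True)
    x = torch.randn(2, 64, 56, 56)
    out = block(x)                 # (2, 128, 28, 28)
    return out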
|
|
|
|
|
|
|
|
class ConvNet(nn.Module): |
|
|
def __init__(self, depth, flatten = True): |
|
|
super(ConvNet,self).__init__() |
|
|
self.grads = [] |
|
|
self.fmaps = [] |
|
|
trunk = [] |
|
|
for i in range(depth): |
|
|
indim = 3 if i == 0 else 64 |
|
|
outdim = 64 |
|
|
B = ConvBlock(indim, outdim, pool = ( i <4 ) ) |
|
|
trunk.append(B) |
|
|
|
|
|
if flatten: |
|
|
trunk.append(Flatten()) |
|
|
|
|
|
self.trunk = nn.Sequential(*trunk) |
|
|
self.final_feat_dim = 1600 |
|
|
|
|
|
def forward(self,x): |
|
|
out = self.trunk(x) |
|
|
return out |
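

# Usage sketch (illustrative): Conv4 on an 84x84 image produces a 64x5x5 feature
# map, which Flatten turns into the 1600-dimensional vector recorded in
# final_feat_dim.
def _demo_convnet():
    net = ConvNet(4, flatten=True)
    x = torch.randn(2, 3, 84, 84)
    feats = net(x)                 # (2, 1600)
    return feats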
|
|
|
|
|
|
|
|
class ConvNetNopool(nn.Module): |
|
|
def __init__(self, depth): |
|
|
super(ConvNetNopool,self).__init__() |
|
|
self.grads = [] |
|
|
self.fmaps = [] |
|
|
trunk = [] |
|
|
for i in range(depth): |
|
|
indim = 3 if i == 0 else 64 |
|
|
outdim = 64 |
|
|
B = ConvBlock(indim, outdim, pool = ( i in [0,1] ), padding = 0 if i in[0,1] else 1 ) |
|
|
trunk.append(B) |
|
|
|
|
|
self.trunk = nn.Sequential(*trunk) |
|
|
self.final_feat_dim = [64,19,19] |
|
|
|
|
|
def forward(self,x): |
|
|
out = self.trunk(x) |
|
|
return out |
|
|
|
|
|
|
|
|
class ResNet(nn.Module): |
|
|
maml = False |
|
|
def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten=True, leakyrelu=False): |
|
|
|
|
|
|
|
|
super(ResNet,self).__init__() |
|
|
self.grads = [] |
|
|
self.fmaps = [] |
|
|
assert len(list_of_num_layers)==4, 'Can have only four stages' |
|
|
if self.maml: |
|
|
conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
bn1 = BatchNorm2d_fw(64) |
|
|
else: |
|
|
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
bn1 = nn.BatchNorm2d(64) |
|
|
|
|
|
relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True) |
|
|
pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
|
|
|
|
|
init_layer(conv1) |
|
|
init_layer(bn1) |
|
|
|
|
|
trunk = [conv1, bn1, relu, pool1] |
|
|
|
|
|
indim = 64 |
|
|
for i in range(4): |
|
|
for j in range(list_of_num_layers[i]): |
|
|
half_res = (i>=1) and (j==0) |
|
|
B = block(indim, list_of_out_dims[i], half_res, leaky=leakyrelu) |
|
|
trunk.append(B) |
|
|
indim = list_of_out_dims[i] |
|
|
|
|
|
if flatten: |
|
|
avgpool = nn.AvgPool2d(7) |
|
|
trunk.append(avgpool) |
|
|
trunk.append(Flatten()) |
|
|
self.final_feat_dim = indim |
|
|
else: |
|
|
self.final_feat_dim = [ indim, 7, 7] |
|
|
|
|
|
self.trunk = nn.Sequential(*trunk) |
|
|
|
|
|
def forward(self,x): |
|
|
out = self.trunk(x) |
|
|
return out |
|
|
|
|
|
|
|
|
def forward_block1(self, x): |
|
|
out = self.trunk[:5](x) |
|
|
return out |
|
|
|
|
|
def forward_block2(self, x): |
|
|
out = self.trunk[5:6](x) |
|
|
return out |
|
|
|
|
|
def forward_block3(self, x): |
|
|
out = self.trunk[6:7](x) |
|
|
return out |
|
|
|
|
|
def forward_block4(self, x): |
|
|
out = self.trunk[7:8](x) |
|
|
return out |
|
|
|
|
def forward_rest(self,x): |
|
|
out = self.trunk[8:](x) |
|
|
return out |
|
|
|
|
|
|
|
|
class ResNet_Multi(nn.Module): |
|
|
maml = False |
|
|
def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten=True, leakyrelu=False): |
|
|
|
|
|
|
|
|
super(ResNet_Multi,self).__init__() |
|
|
self.grads = [] |
|
|
self.fmaps = [] |
|
|
assert len(list_of_num_layers)==4, 'Can have only four stages' |
|
|
if self.maml: |
|
|
conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
bn1 = BatchNorm2d_fw(64) |
|
|
else: |
|
|
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
bn1 = nn.BatchNorm2d(64) |
|
|
|
|
|
relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True) |
|
|
pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
|
|
|
|
|
init_layer(conv1) |
|
|
init_layer(bn1) |
|
|
|
|
|
trunk = [conv1, bn1, relu, pool1] |
|
|
|
|
|
indim = 64 |
|
|
for i in range(4): |
|
|
for j in range(list_of_num_layers[i]): |
|
|
half_res = (i>=1) and (j==0) |
|
|
B = block(indim, list_of_out_dims[i], half_res, leaky=leakyrelu) |
|
|
trunk.append(B) |
|
|
indim = list_of_out_dims[i] |
|
|
|
|
|
if flatten: |
|
|
avgpool = nn.AvgPool2d(7) |
|
|
trunk.append(avgpool) |
|
|
trunk.append(Flatten()) |
|
|
self.final_feat_dim = indim |
|
|
else: |
|
|
self.final_feat_dim = [ indim, 7, 7] |
|
|
|
|
|
self.trunk = nn.Sequential(*trunk) |
|
|
|
|
|
def forward(self,x): |
|
|
|
|
|
layer1 = self.trunk[:5](x) |
|
|
|
|
|
layer2 = self.trunk[5:6](layer1) |
|
|
|
|
|
layer3 = self.trunk[6:7](layer2) |
|
|
|
|
|
layer4 = self.trunk[7:8](layer3) |
|
|
|
|
|
out = self.trunk[8:](layer4) |
|
|
|
|
|
return layer1, layer2, layer3, layer4, out |
|
|
|
|
|
|
|
|
|
|
|
def Conv4(): |
|
|
return ConvNet(4) |
|
|
def Conv6(): |
|
|
return ConvNet(6) |
|
|
def Conv4NP(): |
|
|
return ConvNetNopool(4) |
|
|
def Conv6NP(): |
|
|
return ConvNetNopool(6) |
|
|
|
|
|
|
|
|
def ResNet10(flatten=True, leakyrelu=False): |
|
|
    print('backbone: ResNet10')
|
|
return ResNet(SimpleBlock, [1,1,1,1],[64,128,256,512], flatten, leakyrelu) |
|
|
def ResNet10_Multi(flatten=True, leakyrelu=False): |
|
|
    print('backbone: ResNet10_Multi')
|
|
return ResNet_Multi(SimpleBlock, [1,1,1,1],[64,128,256,512], flatten, leakyrelu) |
|
|
def ResNet18(flatten=True, leakyrelu=False): |
|
|
return ResNet(SimpleBlock, [2,2,2,2],[64,128,256,512], flatten, leakyrelu) |
|
|
def ResNet34(flatten=True, leakyrelu=False): |
|
|
return ResNet(SimpleBlock, [3,4,6,3],[64,128,256,512], flatten, leakyrelu) |
|
|
|
|
|
model_dict = dict(Conv4 = Conv4, |
|
|
Conv6 = Conv6, |
|
|
ResNet10 = ResNet10, |
|
|
ResNet10_Multi = ResNet10_Multi, |
|
|
ResNet18 = ResNet18, |
|
|
ResNet34 = ResNet34) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
model_func = model_dict['ResNet10'] |
|
|
net = model_func(flatten = True, leakyrelu= False) |
|
|
    x = torch.randn(16, 3, 224, 224)  # torch.autograd.Variable is deprecated; plain tensors suffice
|
|
out = net(x) |
|
|
print(out.size()) |
|
|
|
|
|
|
|
|
model_func_2 = model_dict['ResNet10_Multi'] |
|
|
net2 = model_func_2(flatten = True, leakyrelu = False) |
|
|
layer1, layer2, layer3, layer4, out2 = net2(x) |
|
|
print('net-multi:', layer1.size(), layer2.size(), layer3.size(), layer4.size(), out2.size()) |
|
|
|
|
|
|
|
|
|
|
|
print('------------------') |
|
|
model_func = model_dict['ResNet10'] |
|
|
net = model_func(flatten = True, leakyrelu= False) |
|
|
    x = torch.randn(16, 3, 224, 224)
|
|
out = net(x) |
|
|
print(out.size()) |
|
|
|
|
|
print(net) |
|
|
block1 = net.forward_block1(x) |
|
|
print('block1:', block1.size()) |
|
|
|
|
|
block2 = net.forward_block2(block1) |
|
|
print('block2:', block2.size()) |
|
|
|
|
|
block3 = net.forward_block3(block2) |
|
|
print('block3:', block3.size()) |
|
|
|
|
|
block4 = net.forward_block4(block3) |
|
|
print('block4:', block4.size()) |
|
|
|
|
|
    # The layers after block4 are exposed as forward_rest (forward_block5 is not defined).
    rest = net.forward_rest(block4)
    print('rest:', rest.size())
|
|
|