CausalStyleAdv / methods /backbone.py
YuqianFu's picture
Upload folder using huggingface_hub
197d4ca verified
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.nn.utils import weight_norm
# --- gaussian initialize ---
def init_layer(L):
    """Re-initialise a single layer in place.

    * ``nn.BatchNorm2d`` (and subclasses): weight set to 1, bias set to 0.
    * ``nn.Conv2d`` (and subclasses): weight drawn from N(0, 2 / n) with
      n = kernel_h * kernel_w * out_channels (He-normal in its fan-out form);
      any conv bias is left as it is.
    * Any other module is left unchanged.
    """
    if isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
        return
    if not isinstance(L, nn.Conv2d):
        return
    kernel_h, kernel_w = L.kernel_size[0], L.kernel_size[1]
    fan_out = float(kernel_h * kernel_w * L.out_channels)
    std = math.sqrt(2.0 / fan_out)
    L.weight.data.normal_(0, std)
class distLinear(nn.Module):
    """Cosine-similarity classifier head.

    ``forward`` L2-normalises each row of ``x`` along dim 1 (with a 1e-5
    epsilon), applies a bias-free ``nn.Linear`` wrapped in ``weight_norm``
    (``dim=0``, i.e. one magnitude per output row), and multiplies the result
    by a fixed factor of 10.

    Assumes ``x`` is (batch, indim) -- TODO confirm against callers.
    """

    def __init__(self, indim, outdim):
        super(distLinear, self).__init__()
        self.L = weight_norm(nn.Linear(indim, outdim, bias=False), name='weight', dim=0)
        self.relu = nn.ReLU()  # registered submodule; not used in forward

    def forward(self, x):
        eps = 0.00001
        row_norms = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        unit_x = x.div(row_norms + eps)
        # NOTE(review): this rescales the currently computed ``weight`` tensor, but
        # weight_norm recomputes ``weight`` from weight_g / weight_v in a forward
        # pre-hook when self.L is called below. The rescale therefore probably does
        # not affect the scores -- confirm before relying on it.
        w = self.L.weight.data
        w_norms = torch.norm(w, p=2, dim=1).unsqueeze(1).expand_as(w)
        self.L.weight.data = w.div(w_norms + eps)
        cos_dist = self.L(unit_x)
        return 10 * cos_dist
# --- flatten tensor ---
class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis.

    Uses ``Tensor.view``, so the input must be view-compatible (e.g. contiguous).
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat
# --- LSTMCell module for matchingnet ---
class LSTMCell(nn.Module):
    """Single-step LSTM cell with optional MAML fast-weight projections.

    If the class attribute ``maml`` is True when an instance is built, the
    input-to-hidden and hidden-to-hidden projections are ``Linear_fw`` layers.
    Otherwise they are plain ``nn.Linear`` layers. All parameters are
    initialised from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).
    """
    maml = False

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        if self.maml:
            self.x2h = Linear_fw(input_size, 4 * hidden_size, bias=bias)
            self.h2h = Linear_fw(hidden_size, 4 * hidden_size, bias=bias)
        else:
            self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
            self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden=None):
        """Advance the cell by one step.

        Args:
            x: input of shape (batch, input_size).
            hidden: optional ``(h, c)`` pair, each (batch, hidden_size); zeros when None.

        Returns:
            ``(hy, cy)``: the new hidden and cell states, each (batch, hidden_size).
        """
        if hidden is None:
            # The recurrent state has hidden_size features (it feeds h2h), so it
            # must not be shaped like x, whose width is input_size.
            hx = x.new_zeros(x.size(0), self.hidden_size)
            cx = x.new_zeros(x.size(0), self.hidden_size)
        else:
            hx, cx = hidden
        gates = self.x2h(x) + self.h2h(hx)
        ingate, forgetgate, cellgate, outgate = torch.split(gates, self.hidden_size, dim=1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate)
        hy = torch.mul(outgate, torch.tanh(cy))
        return (hy, cy)
# --- LSTM module for matchingnet ---
class LSTM(nn.Module):
    """Single-layer, optionally bidirectional LSTM built from ``LSTMCell``.

    Input ``x`` is (seq, batch, input_size), or (batch, seq, input_size) when
    ``batch_first``. Only the output sequence is returned (no final state). Its
    last dimension is hidden_size, or 2 * hidden_size when bidirectional:
    forward states first, then backward states aligned to the same time step.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, bidirectional=False):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.num_directions = 2 if bidirectional else 1
        # Only a single (non-stacked) layer is implemented.
        assert(self.num_layers == 1)
        # NOTE(review): both directions share this one cell's weights (nn.LSTM keeps
        # separate per-direction weights); kept so existing checkpoints still load.
        self.lstm = LSTMCell(input_size, hidden_size, self.bias)

    def forward(self, x, hidden=None):
        """Run the sequence through the cell.

        ``hidden`` is an optional ``(h0, c0)`` pair, each (num_directions, batch,
        hidden_size); zeros when None.
        """
        # work in (seq, batch, feature) layout internally
        if self.batch_first:
            x = x.permute(1, 0, 2)
        seq_len = x.size(0)
        if hidden is None:
            h0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device)
            c0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device)
        else:
            h0, c0 = hidden
        # forward direction: t = 0 .. seq_len-1
        outs = []
        hn, cn = h0[0], c0[0]
        for t in range(seq_len):
            hn, cn = self.lstm(x[t], (hn, cn))
            outs.append(hn.unsqueeze(0))
        outs = torch.cat(outs, dim=0)
        # backward direction: walk t = seq_len-1 .. 0 (the sequence axis is dim 0,
        # not the batch axis) and store each state at its own time step so it
        # lines up with the forward output for that step.
        if self.num_directions == 2:
            outs_reverse = [None] * seq_len
            hn, cn = h0[1], c0[1]
            for t in reversed(range(seq_len)):
                hn, cn = self.lstm(x[t], (hn, cn))
                outs_reverse[t] = hn.unsqueeze(0)
            outs_reverse = torch.cat(outs_reverse, dim=0)
            outs = torch.cat([outs, outs_reverse], dim=2)
        # restore the caller's layout
        if self.batch_first:
            outs = outs.permute(1, 0, 2)
        return outs
# --- Linear module ---
class Linear_fw(nn.Linear):  # used in MAML to forward input with fast weight
    """``nn.Linear`` that can run with temporary "fast" parameters (MAML inner loop).

    Callers attach tensors to ``self.weight.fast`` / ``self.bias.fast``. When
    the fast tensors are set (just the weight when the layer has no bias), they
    are used instead of the stored parameters. ``bias=False`` is supported, as
    in ``Conv2d_fw``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear_fw, self).__init__(in_features, out_features, bias=bias)
        self.weight.fast = None  # Lazy hack to add fast weight link
        # With bias=False, self.bias is None and cannot carry a .fast attribute.
        if self.bias is not None:
            self.bias.fast = None

    def forward(self, x):
        if self.bias is None:
            if self.weight.fast is not None:
                return F.linear(x, self.weight.fast, None)
        elif self.weight.fast is not None and self.bias.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        # fall back to the stored (slow) parameters
        return super(Linear_fw, self).forward(x)
# --- Conv2d module ---
class Conv2d_fw(nn.Conv2d):  # used in MAML to forward input with fast weight
    """``nn.Conv2d`` that can run with temporary "fast" parameters (MAML inner loop).

    The fast path is taken when ``self.weight.fast`` is set and, if the layer
    has a bias, ``self.bias.fast`` is set too. Only ``stride`` and ``padding``
    are forwarded to ``F.conv2d`` on the fast path.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
        self.weight.fast = None
        if self.bias is not None:
            self.bias.fast = None

    def forward(self, x):
        has_bias = self.bias is not None
        fast_ready = self.weight.fast is not None and (not has_bias or self.bias.fast is not None)
        if not fast_ready:
            return super(Conv2d_fw, self).forward(x)
        fast_bias = self.bias.fast if has_bias else None
        return F.conv2d(x, self.weight.fast, fast_bias, stride=self.stride, padding=self.padding)
# --- softplus module ---
def softplus(x):
    """Sharp softplus, ``log(1 + exp(100 * x)) / 100``: a smooth, strictly positive ReLU.

    Delegates to ``torch.nn.functional.softplus`` with ``beta=100``.
    """
    return F.softplus(x, beta=100)
# --- feature-wise transformation layer ---
class FeatureWiseTransformation2d_fw(nn.BatchNorm2d):
    """BatchNorm2d with MAML fast-weight support and an optional feature-wise transformation.

    * Fast weights: when ``self.weight.fast`` and ``self.bias.fast`` are both
      set, they replace the affine parameters.
    * Feature-wise transformation: if the class attribute ``feature_augment``
      is True, learnable ``gamma``/``beta`` (initialised to 0.3/0.5) are
      created at construction. In training mode the normalised output is then
      transformed as ``g * out + b``. Here ``g = 1 + N(0,1) * softplus(gamma)``
      and ``b = N(0,1) * softplus(beta)``, sampled once per channel on every
      forward pass.
    * With ``track_running_stats=False``, batch statistics are always used,
      even in eval mode.
    """
    feature_augment = False

    def __init__(self, num_features, momentum=0.1, track_running_stats=True):
        super(FeatureWiseTransformation2d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats)
        self.weight.fast = None
        self.bias.fast = None
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.zeros(num_features))
        if self.feature_augment:  # initialize {gamma, beta} with {0.3, 0.5}
            self.gamma = torch.nn.Parameter(torch.ones(1, num_features, 1, 1) * 0.3)
            self.beta = torch.nn.Parameter(torch.ones(1, num_features, 1, 1) * 0.5)
        # resets running stats (mean 0, var 1) and affine params (weight 1, bias 0)
        self.reset_parameters()

    def reset_running_stats(self):
        """Set running mean to 0 and running variance to 1 (when tracked)."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def forward(self, x, step=0):
        # MAML: use the fast affine parameters only when both have been set.
        if self.weight.fast is not None and self.bias.fast is not None:
            weight = self.weight.fast
            bias = self.bias.fast
        else:
            weight = self.weight
            bias = self.bias
        if self.track_running_stats:
            out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum)
        else:
            # No running buffers: always normalise with batch statistics. The
            # throw-away running tensors must hold one entry per channel
            # (x.size(1)), as in BatchNorm2d_fw; tensors shaped like x are rejected.
            channels = x.size(1)
            dummy_mean = torch.zeros(channels, dtype=x.dtype, device=x.device)
            dummy_var = torch.ones(channels, dtype=x.dtype, device=x.device)
            out = F.batch_norm(x, dummy_mean, dummy_var, weight, bias, training=True, momentum=1)
        # apply feature-wise transformation (training only)
        if self.feature_augment and self.training:
            gamma = (1 + torch.randn(1, self.num_features, 1, 1, dtype=self.gamma.dtype, device=self.gamma.device) * softplus(self.gamma)).expand_as(out)
            beta = (torch.randn(1, self.num_features, 1, 1, dtype=self.beta.dtype, device=self.beta.device) * softplus(self.beta)).expand_as(out)
            out = gamma * out + beta
        return out
# --- BatchNorm2d ---
class BatchNorm2d_fw(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` that can run with MAML "fast" affine parameters.

    When both ``self.weight.fast`` and ``self.bias.fast`` are set, they replace
    the stored weight/bias. With ``track_running_stats=False``, batch
    statistics are always used, even in eval mode.
    """

    def __init__(self, num_features, momentum=0.1, track_running_stats=True):
        super(BatchNorm2d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats)
        for param in (self.weight, self.bias):
            param.fast = None
        if self.track_running_stats:
            for buf_name in ('running_mean', 'running_var'):
                self.register_buffer(buf_name, torch.zeros(num_features))
        # restores mean 0 / var 1 and weight 1 / bias 0
        self.reset_parameters()

    def reset_running_stats(self):
        """Set running mean to 0 and running variance to 1 (when tracked)."""
        if not self.track_running_stats:
            return
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def forward(self, x, step=0):
        use_fast = self.weight.fast is not None and self.bias.fast is not None
        weight, bias = (self.weight.fast, self.bias.fast) if use_fast else (self.weight, self.bias)
        if not self.track_running_stats:
            # per-channel throw-away running stats; batch statistics are used
            channels = x.size(1)
            dummy_mean = torch.zeros(channels, dtype=x.dtype, device=x.device)
            dummy_var = torch.ones(channels, dtype=x.dtype, device=x.device)
            return F.batch_norm(x, dummy_mean, dummy_var, weight, bias, training=True, momentum=1)
        return F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum)
# --- BatchNorm1d ---
class BatchNorm1d_fw(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` that can run with MAML "fast" affine parameters.

    When both ``self.weight.fast`` and ``self.bias.fast`` are set, they replace
    the stored weight/bias. With ``track_running_stats=False``, batch
    statistics are always used, even in eval mode.
    """

    def __init__(self, num_features, momentum=0.1, track_running_stats=True):
        super(BatchNorm1d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats)
        self.weight.fast, self.bias.fast = None, None
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.zeros(num_features))
        # restores mean 0 / var 1 and weight 1 / bias 0
        self.reset_parameters()

    def reset_running_stats(self):
        """Set running mean to 0 and running variance to 1 (when tracked)."""
        if not self.track_running_stats:
            return
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def forward(self, x, step=0):
        fast_w, fast_b = self.weight.fast, self.bias.fast
        if fast_w is None or fast_b is None:
            fast_w, fast_b = self.weight, self.bias
        if self.track_running_stats:
            return F.batch_norm(x, self.running_mean, self.running_var, fast_w, fast_b, training=self.training, momentum=self.momentum)
        # per-channel throw-away running stats; batch statistics are used
        n_ch = x.size(1)
        return F.batch_norm(
            x,
            torch.zeros(n_ch, dtype=x.dtype, device=x.device),
            torch.ones(n_ch, dtype=x.dtype, device=x.device),
            fast_w,
            fast_b,
            training=True,
            momentum=1,
        )
# --- Simple Conv Block ---
class ConvBlock(nn.Module):
    """3x3 conv -> norm -> ReLU, optionally followed by a 2x2 max-pool.

    If the class attribute ``maml`` is True when an instance is built, the
    fast-weight layers ``Conv2d_fw`` / ``FeatureWiseTransformation2d_fw`` are
    used. Otherwise the layers are ``nn.Conv2d`` / ``nn.BatchNorm2d``. Every
    layer in the trunk is passed through ``init_layer``.
    """
    maml = False

    def __init__(self, indim, outdim, pool=True, padding=1):
        super(ConvBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        if self.maml:
            conv_cls, norm_cls = Conv2d_fw, FeatureWiseTransformation2d_fw
        else:
            conv_cls, norm_cls = nn.Conv2d, nn.BatchNorm2d
        self.C = conv_cls(indim, outdim, 3, padding=padding)
        self.BN = norm_cls(outdim)
        self.relu = nn.ReLU(inplace=True)
        layers = [self.C, self.BN, self.relu]
        if pool:
            self.pool = nn.MaxPool2d(2)
            layers.append(self.pool)
        self.parametrized_layers = layers
        for layer in layers:
            init_layer(layer)
        self.trunk = nn.Sequential(*layers)

    def forward(self, x):
        return self.trunk(x)
# --- Simple ResNet Block ---
class SimpleBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> act -> 3x3 conv -> BN, add shortcut, act.

    Args:
        indim: input channels.
        outdim: output channels.
        half_res: if True, the first conv (and the projection shortcut) use stride 2.
        leaky: use LeakyReLU(0.2) instead of ReLU.

    If the class attribute ``maml`` is True when an instance is built,
    fast-weight layers are used. The second BN and the shortcut BN are then
    ``FeatureWiseTransformation2d_fw``. A 1x1 projection shortcut is used
    whenever the identity would not match the residual branch's shape.
    """
    maml = False

    def __init__(self, indim, outdim, half_res, leaky=False):
        super(SimpleBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        stride = 2 if half_res else 1
        if self.maml:
            self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=stride, padding=1, bias=False)
            self.BN1 = BatchNorm2d_fw(outdim)
            self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
            self.BN2 = FeatureWiseTransformation2d_fw(outdim)  # feature-wise transformation at the end of each residual block
        else:
            self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=stride, padding=1, bias=False)
            self.BN1 = nn.BatchNorm2d(outdim)
            self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
            self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True)
        self.relu2 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True)
        self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
        self.half_res = half_res
        # The identity can only be added when the residual branch keeps both the
        # channel count and the spatial size. A channel change or a stride-2
        # downsample requires a (strided) 1x1 projection of x instead.
        if indim != outdim or half_res:
            if self.maml:
                self.shortcut = Conv2d_fw(indim, outdim, 1, stride, bias=False)
                self.BNshortcut = FeatureWiseTransformation2d_fw(outdim)
            else:
                self.shortcut = nn.Conv2d(indim, outdim, 1, stride, bias=False)
                self.BNshortcut = nn.BatchNorm2d(outdim)
            self.parametrized_layers.append(self.shortcut)
            self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'
        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        out = self.relu1(self.BN1(self.C1(x)))
        out = self.BN2(self.C2(out))
        short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
        return self.relu2(out + short_out)
# --- ConvNet module ---
class ConvNet(nn.Module):
    """Stack of ``depth`` 64-channel ConvBlocks (the Conv4 / Conv6 backbones).

    Only the first four blocks max-pool. When ``flatten`` is True, a Flatten
    layer is appended.

    NOTE(review): ``final_feat_dim`` is hard-coded to 1600 (= 64 * 5 * 5), which
    presumably assumes 84x84 inputs. It is reported even when ``flatten=False``
    -- confirm against callers.
    """

    def __init__(self, depth, flatten=True):
        super(ConvNet, self).__init__()
        self.grads = []
        self.fmaps = []
        # only the first four blocks pool
        blocks = [ConvBlock(3 if i == 0 else 64, 64, pool=(i < 4)) for i in range(depth)]
        if flatten:
            blocks.append(Flatten())
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = 1600

    def forward(self, x):
        return self.trunk(x)
# --- ConvNetNopool module ---
class ConvNetNopool(nn.Module):  # Relation net use a 4 layer conv with pooling in only first two layers, else no pooling
    """Stack of ``depth`` 64-channel ConvBlocks without a final flatten.

    The first two blocks use no padding and max-pool. Later blocks use
    padding 1 and do not pool, so they keep the spatial size.

    NOTE(review): ``final_feat_dim`` = [64, 19, 19] presumably assumes 84x84
    inputs -- confirm.
    """

    def __init__(self, depth):
        super(ConvNetNopool, self).__init__()
        self.grads = []
        self.fmaps = []
        blocks = []
        for i in range(depth):
            early = i in (0, 1)
            blocks.append(ConvBlock(3 if i == 0 else 64, 64, pool=early, padding=0 if early else 1))
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = [64, 19, 19]

    def forward(self, x):
        return self.trunk(x)
# --- ResNet module ---
class ResNet(nn.Module):
    """ResNet backbone: 7x7/2 stem conv -> BN -> act -> 3x3/2 max-pool, then four stages of ``block``.

    Args:
        block: residual block class, called as ``block(indim, outdim, half_res, leaky=...)``.
        list_of_num_layers: number of blocks in each of the four stages.
        list_of_out_dims: output channels of each stage.
        flatten: if True, append a 7x7 average pool and Flatten. ``final_feat_dim``
            is then the last stage width; otherwise it is ``[width, 7, 7]``.
        leakyrelu: use LeakyReLU(0.2) instead of ReLU.

    The first block of every stage after the first gets ``half_res=True``. If
    the class attribute ``maml`` is True, the stem uses fast-weight layers.

    NOTE(review): the fixed 7x7 pooling and ``[width, 7, 7]`` presumably assume
    224x224 inputs -- confirm.
    """
    maml = False
    print('backbone:', 'maml:', maml)

    def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True, leakyrelu=False):
        super(ResNet, self).__init__()
        self.grads = []
        self.fmaps = []
        assert len(list_of_num_layers) == 4, 'Can have only four stages'
        conv_cls, bn_cls = (Conv2d_fw, BatchNorm2d_fw) if self.maml else (nn.Conv2d, nn.BatchNorm2d)
        stem_conv = conv_cls(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        stem_bn = bn_cls(64)
        act = nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True)
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        init_layer(stem_conv)
        init_layer(stem_bn)
        layers = [stem_conv, stem_bn, act, stem_pool]
        width = 64
        for stage in range(4):
            for idx in range(list_of_num_layers[stage]):
                downsample = stage >= 1 and idx == 0
                layers.append(block(width, list_of_out_dims[stage], downsample, leaky=leakyrelu))
                width = list_of_out_dims[stage]
        if flatten:
            layers += [nn.AvgPool2d(7), Flatten()]
            self.final_feat_dim = width
        else:
            self.final_feat_dim = [width, 7, 7]
        self.trunk = nn.Sequential(*layers)

    def forward(self, x):
        return self.trunk(x)
# --- Conv networks ---
def Conv4():
    """Four-block ConvNet backbone (every block pools; flattened output)."""
    return ConvNet(depth=4)


def Conv6():
    """Six-block ConvNet backbone (only the first four blocks pool; flattened output)."""
    return ConvNet(depth=6)


def Conv4NP():
    """Four-block ConvNetNopool backbone (pooling only in the first two blocks)."""
    return ConvNetNopool(depth=4)


def Conv6NP():
    """Six-block ConvNetNopool backbone (pooling only in the first two blocks)."""
    return ConvNetNopool(depth=6)
# --- ResNet networks ---
def ResNet10(flatten=True, leakyrelu=False):
    """ResNet-10: one SimpleBlock per stage, stage widths 64/128/256/512."""
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten=flatten, leakyrelu=leakyrelu)


def ResNet18(flatten=True, leakyrelu=False):
    """ResNet-18: two SimpleBlocks per stage, stage widths 64/128/256/512."""
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten=flatten, leakyrelu=leakyrelu)


def ResNet34(flatten=True, leakyrelu=False):
    """ResNet-34: 3/4/6/3 SimpleBlocks per stage, stage widths 64/128/256/512."""
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten=flatten, leakyrelu=leakyrelu)
# Registry: backbone name -> zero-argument-callable constructor.
model_dict = {
    'Conv4': Conv4,
    'Conv6': Conv6,
    'ResNet10': ResNet10,
    'ResNet18': ResNet18,
    'ResNet34': ResNet34,
}