Smile_Changer / criteria /id_vit_loss.py
LogicGoInfotechSpaces's picture
Bundle StyleFeatureEditor code packages in Space to fix ModuleNotFoundError
95b1715
import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn import Parameter

from configs.paths import DefaultPaths
from models.psp.encoders.model_irse import Backbone
#from IPython import embed
# Patch-count threshold: ViTs_face asserts that num_patches > MIN_NUM_PATCHES.
MIN_NUM_PATCHES = 16
class Softmax(nn.Module):
    r"""Plain linear classification head (no margin).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: sequence of GPU ids to shard the classifier across
            (model parallel), or None to run a single ``F.linear`` without
            sharding on whatever device the input and parameters live on.
    """

    def __init__(self, in_features, out_features, device_id):
        super(Softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input, label):
        """Return the logits for ``input``.

        ``label`` is accepted so the head has the same call signature as the
        margin-based heads in this file; it is not used here.
        """
        # Bind the alias before branching: the non-sharded branch previously
        # read ``x`` before it was assigned and raised NameError.
        x = input
        if self.device_id is None:
            return F.linear(x, self.weight, self.bias)

        # Model parallel: each device computes the logits for its slice of the
        # classes, and the slices are gathered on the first device.
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        bias = sub_biases[0].cuda(self.device_id[0])
        out = F.linear(temp_x, weight, bias)
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            bias = sub_biases[i].cuda(self.device_id[i])
            out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
        return out

    def _initialize_weights(self):
        """Re-initialise any Conv2d / BatchNorm / Linear sub-modules in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
class ArcFace(nn.Module):
    r"""Implement of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: sequence of GPU ids to shard the class weights across
            (model parallel), or None to compute everything without sharding
            on whatever device the input and parameters live on.
        s: norm of input feature (scale applied to the final logits)
        m: margin, applied as cos(theta + m) on the target class
        easy_margin: if True, the margin is only applied where cos(theta) > 0
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.50, easy_margin=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold / fallback used when theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def _cosine(self, x):
        """Cosine similarity between the rows of ``x`` and every class weight."""
        if self.device_id is None:
            return F.linear(F.normalize(x), F.normalize(self.weight))

        # Model parallel: every device scores its slice of the classes and the
        # partial results are gathered on the first device.
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cosine = torch.cat(
                (cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                dim=1)
        return cosine

    def forward(self, input, label):
        """Return scaled logits with the angular margin applied to ``label``'s class."""
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = self._cosine(input)
        # Rounding can push cos^2 marginally above 1; clamp so sqrt never sees
        # a negative value and returns NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0.0, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # zeros_like keeps the mask on the same device/dtype as the logits,
        # which also covers the non-sharded case when the module lives on a GPU.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.to(cosine.device).view(-1, 1).long(), 1)
        # Target column takes phi, every other column keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
class CosFace(nn.Module):
    r"""Implement of CosFace (https://arxiv.org/pdf/1801.09414.pdf):

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: sequence of GPU ids to shard the class weights across
            (model parallel), or None to compute everything without sharding
            on whatever device the input and parameters live on.
        s: norm of input feature (scale applied to the final logits)
        m: margin, applied as cos(theta) - m on the target class
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m
        print("self.device_id", self.device_id)
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def _cosine(self, x):
        """Cosine similarity between the rows of ``x`` and every class weight."""
        if self.device_id is None:
            return F.linear(F.normalize(x), F.normalize(self.weight))

        # Model parallel: every device scores its slice of the classes and the
        # partial results are gathered on the first device.
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cosine = torch.cat(
                (cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                dim=1)
        return cosine

    def forward(self, input, label):
        """Return scaled logits with the additive cosine margin on ``label``'s class."""
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = self._cosine(input)
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        # Follow the logits' device instead of indexing self.device_id[0]:
        # the old code subscripted device_id even when it was None and raised
        # TypeError on the non-sharded path.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.to(cosine.device).view(-1, 1).long(), 1)
        # Target column takes phi, every other column keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

    def __repr__(self):
        return (f"{self.__class__.__name__}("
                f"in_features = {self.in_features}"
                f", out_features = {self.out_features}"
                f", s = {self.s}"
                f", m = {self.m})")
class SFaceLoss(nn.Module):
    """Classification head that also returns intra-/inter-class loss terms.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: sequence of GPU ids to shard the class weights across
            (model parallel), or None to compute everything without sharding.
        s: scale applied to the cosine logits
        k: slope of the sigmoid re-weighting curves
        a: angle offset of the intra-class (target) re-weighting curve
        b: angle offset of the inter-class (non-target) re-weighting curve
    """

    def __init__(self, in_features, out_features, device_id, s = 64.0, k = 80.0, a = 0.90, b = 1.2):
        super(SFaceLoss, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.k = k
        self.a = a
        self.b = b
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        # The original line called xavier_normal_(self.weight, gain=2, mode='out'),
        # a helper that is not defined or imported in this file (NameError).
        # NOTE(review): reimplemented here as a zero-mean normal with
        # std = gain * sqrt(2 / fan_out), fan_out being the number of weight
        # rows -- presumably what the missing helper did; confirm against the
        # code this class was ported from. forward() L2-normalises the weight
        # rows, so the scale only matters for optimisation, not for outputs.
        gain = 2.0
        std = gain * math.sqrt(2.0 / float(out_features))
        nn.init.normal_(self.weight, 0.0, std)

    def _cosine(self, x):
        """Cosine similarity between the rows of ``x`` and every class weight."""
        if self.device_id is None:
            return F.linear(F.normalize(x), F.normalize(self.weight))

        # Model parallel: every device scores its slice of the classes and the
        # partial results are gathered on the first device.
        sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
        temp_x = x.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cosine = torch.cat(
                (cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                dim=1)
        return cosine

    def forward(self, input, label):
        """Return ``(logits, loss, intra_mean, inter_mean, target_cos_mean, other_cos_mean)``."""
        # --------------------------- cos(theta) ---------------------------
        cosine = self._cosine(input)
        # --------------------------- s*cos(theta) ---------------------------
        output = cosine * self.s
        # --------------------------- sface loss ---------------------------
        # Build both masks on the logits' device; scatter_ needs int64 indices.
        index = label.to(cosine.device).view(-1, 1).long()
        one_hot = torch.zeros_like(cosine).scatter_(1, index, 1)
        zero_hot = torch.ones_like(cosine).scatter_(1, index, 0)

        WyiX = torch.sum(one_hot * output, 1)
        with torch.no_grad():
            # Re-weighting factors are treated as constants (no gradient).
            # Clamp before acos: rounding can leave |cos| marginally above 1.
            theta_yi = torch.acos((WyiX / self.s).clamp(-1.0, 1.0))
            weight_yi = 1.0 / (1.0 + torch.exp(-self.k * (theta_yi - self.a)))
        intra_loss = - weight_yi * WyiX

        Wj = zero_hot * output
        with torch.no_grad():
            theta_j = torch.acos((Wj / self.s).clamp(-1.0, 1.0))
            weight_j = 1.0 / (1.0 + torch.exp(self.k * (theta_j - self.b)))
        inter_loss = torch.sum(weight_j * Wj, 1)

        loss = intra_loss.mean() + inter_loss.mean()
        Wyi_s = WyiX / self.s
        Wj_s = Wj / self.s
        return output, loss, intra_loss.mean(), inter_loss.mean(), Wyi_s.mean(), Wj_s.mean()
class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded untouched to the wrapped callable.
        transformed = self.fn(x, **kwargs)
        return transformed + x
class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Keyword arguments are passed through to the wrapped callable only.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Layer order (and therefore the Sequential indices that appear in
        # state-dict keys) is kept exactly as listed here.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
class Attention(nn.Module):
    """Multi-head self-attention over a (batch, tokens, dim) input.

    Args:
        dim: feature size of each input/output token
        heads: number of attention heads
        dim_head: feature size of each head
        dropout: dropout probability applied after the output projection
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # NOTE(review): scores are scaled by dim ** -0.5, not dim_head ** -0.5.
        # Left as is: changing it would change the output of existing weights.
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        """Attend over the tokens of ``x``.

        ``mask`` is an optional boolean tensor that, once flattened per sample,
        holds one entry per token except the first; a True entry is prepended
        for that first token. A score is kept only where both the query token
        and the key token are True.
        """
        b, n, _ = x.shape
        h = self.heads

        # Split the fused projection and move heads in front of tokens:
        # (b, n, h * d) -> (b, h, n, d).
        q, k, v = (
            t.reshape(b, n, h, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim = -1)
        )

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask_value = -torch.finfo(dots.dtype).max
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Insert an explicit head axis so the (b, 1, n, n) pair mask
            # broadcasts over heads. The previous (b, n, n) mask lined its
            # batch axis up with the head axis of ``dots``.
            pair_mask = mask[:, None, :, None] & mask[:, None, None, :]
            dots.masked_fill_(~pair_mask, mask_value)
            del mask, pair_mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        out = out.transpose(1, 2).reshape(b, n, -1)
        out = self.to_out(out)
        return out
class Transformer(nn.Module):
    """Stack of ``depth`` blocks, each a residual pre-norm attention followed
    by a residual pre-norm feed-forward."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        blocks = []
        for _ in range(depth):
            # Attention is built before the feed-forward in every block.
            attention = Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)))
            feed_forward = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask = None):
        # Only the attention half of each block receives the mask.
        for block in self.layers:
            attend, feed = block
            x = feed(attend(x, mask = mask))
        return x
class ViTs_face(nn.Module):
    """Vision transformer over overlapping image patches, with an optional
    classification head chosen by ``loss_type``.

    Patches of ``ac_patch_size`` x ``ac_patch_size`` pixels are extracted with
    stride ``patch_size`` and padding ``pad``, linearly embedded, prefixed with
    a class token and passed through the transformer. ``forward`` returns the
    layer-normalised pooled embedding, plus the head's output when a label is
    supplied.
    """

    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, ac_patch_size,
                 pad, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * ac_patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
        self.soft_split = nn.Unfold(
            kernel_size=(ac_patch_size, ac_patch_size),
            stride=(self.patch_size, self.patch_size),
            padding=(pad, pad),
        )

        # One positional slot per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )

        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            # Any other unrecognised loss_type leaves ``self.loss`` unset.
            head_by_name = {
                'Softmax': Softmax,
                'CosFace': CosFace,
                'ArcFace': ArcFace,
                'SFace': SFaceLoss,
            }
            head_cls = head_by_name.get(self.loss_type)
            if head_cls is not None:
                self.loss = head_cls(in_features=dim, out_features=num_class, device_id=self.GPU_ID)

    def forward(self, img, label= None , mask = None):
        patches = self.soft_split(img).transpose(1, 2)
        x = self.patch_to_embedding(patches)
        b, n, _ = x.shape

        # Prepend one copy of the class token to every sample.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        if self.pool == 'mean':
            x = x.mean(dim = 1)
        else:
            x = x[:, 0]

        x = self.to_latent(x)
        emb = self.mlp_head(x)
        if label is None:
            return emb
        x = self.loss(emb, label)
        return x, emb
def l2_norm(input, axis=1):
    """Divide ``input`` by its L2 norm taken along ``axis`` (kept as a size-1 dim)."""
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude
class IDVitLoss(nn.Module):
    """Identity loss: mean squared distance between ViT face embeddings.

    Both images are average-pooled to 112x112, embedded with a ``ViTs_face``
    network whose weights are loaded from a checkpoint file, and compared.
    The embedding of the reference image ``y`` is treated as a constant.
    """

    def __init__(self):
        super(IDVitLoss, self).__init__()
        print("Loading Vit ArcFace")
        # Use the GPU when one is present; otherwise stay on the CPU instead
        # of failing on the unconditional .cuda() call.
        use_cuda = torch.cuda.is_available()
        DEVICE = torch.device("cuda:0" if use_cuda else "cpu")
        NUM_CLASS = 93431
        self.facenet = ViTs_face(
            loss_type='CosFace',
            GPU_ID=DEVICE,
            num_class=NUM_CLASS,
            image_size=112,
            patch_size=8,
            ac_patch_size=12,
            pad=4,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        )

        # map_location="cpu" makes loading independent of the device the
        # checkpoint was saved from; the model is moved afterwards.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(
            "pretrained_models/Backbone_VITs_Epoch_2_Batch_12000_Time_2021-03-17-04-05_checkpoint.pth",
            map_location="cpu",
        )
        self.facenet.load_state_dict(state)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        if use_cuda:
            self.facenet = self.facenet.cuda()
        self.facenet = self.facenet.eval()

    def extract_feats(self, x):
        """Pool ``x`` to 112x112 and return its embedding from the face network."""
        x = self.face_pool(x)
        # Run on whatever device the network's parameters live on.
        device = next(self.facenet.parameters()).device
        x_feats = self.facenet(x.to(device))
        return x_feats

    def forward(self, y_hat, y):
        """Return 10000 * mean squared difference between the embeddings of ``y_hat`` and ``y``."""
        # The reference embedding is a fixed target, so no autograd graph is
        # built for it (it was previously computed with grad and then detached).
        with torch.no_grad():
            y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        loss = torch.mean((y_hat_feats - y_feats)**2)
        return loss * 10000