import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn import Parameter

from configs.paths import DefaultPaths
from models.psp.encoders.model_irse import Backbone

MIN_NUM_PATCHES = 16


def xavier_normal_(tensor, gain=1.0, mode='avg'):
    r"""Fill ``tensor`` in place with values drawn from N(0, std^2).

    ``std = gain * sqrt(2 / fan)`` where ``fan`` is selected by ``mode``.

    NOTE(review): the original file called ``xavier_normal_(..., mode='out')``
    without defining or importing it.  This helper follows the formula of
    ``torch.nn.init.xavier_normal_`` with a selectable fan; confirm against
    the upstream SFace implementation.

    Args:
        tensor: tensor with at least two dimensions, modified in place
        gain: scaling factor applied to the standard deviation
        mode: 'avg' (fan_in + fan_out), 'in' (fan_in) or 'out' (fan_out)

    Returns:
        The same ``tensor`` object after initialization.

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions or ``mode``
            is not one of the supported values.
    """
    if tensor.dim() < 2:
        raise ValueError('fan can not be computed for a tensor with fewer than 2 dimensions')
    # For weights with more than two dimensions every (out, in) pair owns a
    # whole kernel, so both fans are scaled by the kernel size.
    receptive_field = tensor[0][0].numel() if tensor.dim() > 2 else 1
    fan_in = tensor.size(1) * receptive_field
    fan_out = tensor.size(0) * receptive_field
    if mode == 'avg':
        fan = fan_in + fan_out
    elif mode == 'in':
        fan = fan_in
    elif mode == 'out':
        fan = fan_out
    else:
        raise ValueError("mode must be 'avg', 'in' or 'out', got %r" % (mode,))
    std = gain * math.sqrt(2.0 / float(fan))
    # The tensor is normally a leaf Parameter; in-place writes need no_grad.
    with torch.no_grad():
        return tensor.normal_(0.0, std)


def _cosine_logits(features, weight, device_id):
    """Return cos(theta) between L2-normalized features and class weights.

    When ``device_id`` is None the product is computed wherever the tensors
    already live.  Otherwise ``weight`` is split along the class axis into
    ``len(device_id)`` chunks, each chunk is evaluated on its own GPU, and the
    partial results are gathered on ``device_id[0]``.
    """
    if device_id is None:
        return F.linear(F.normalize(features), F.normalize(weight))
    primary = device_id[0]
    shards = torch.chunk(weight, len(device_id), dim=0)
    parts = [
        F.linear(F.normalize(features.cuda(dev)), F.normalize(shard.cuda(dev))).cuda(primary)
        for dev, shard in zip(device_id, shards)
    ]
    return torch.cat(parts, dim=1)


def _one_hot_like(logits, label, device_id):
    """Return a float one-hot matrix shaped like ``logits`` with 1 at ``label``.

    When ``device_id`` is given, both the matrix and ``label`` are moved to
    ``device_id[0]``; otherwise everything stays on the CPU.
    """
    one_hot = torch.zeros(logits.size())
    if device_id is not None:
        one_hot = one_hot.cuda(device_id[0])
        label = label.cuda(device_id[0])
    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    return one_hot


class Softmax(nn.Module):
    r"""Implement of Softmax (normal classification head):
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            device_id: the ID of GPU where the model will be trained by model parallel.
                       if device_id=None, it will be trained on CPU without model parallel.
        """

    def __init__(self, in_features, out_features, device_id):
        super(Softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input, label):
        """Return the linear logits for ``input``; ``label`` is not used."""
        if self.device_id is None:
            # The original read an unbound local here; use the argument.
            return F.linear(input, self.weight, self.bias)

        primary = self.device_id[0]
        n_devices = len(self.device_id)
        sub_weights = torch.chunk(self.weight, n_devices, dim=0)
        sub_biases = torch.chunk(self.bias, n_devices, dim=0)
        # Each GPU evaluates its slice of the classes; gather on the first one.
        parts = [
            F.linear(input.cuda(dev), w.cuda(dev), b.cuda(dev)).cuda(primary)
            for dev, w, b in zip(self.device_id, sub_weights, sub_biases)
        ]
        return torch.cat(parts, dim=1)

    def _initialize_weights(self):
        """Re-initialize Conv2d/Linear (Xavier) and BatchNorm (1/0) submodules."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class ArcFace(nn.Module):
    r"""Implement of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            device_id: the ID of GPU where the model will be trained by model parallel.
                       if device_id=None, it will be trained on CPU without model parallel.
            s: norm of input feature
            m: margin
            cos(theta+m)
        """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.50, easy_margin=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return ``s``-scaled logits with the angular margin on the target class."""
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = _cosine_logits(input, self.weight, self.device_id)
        # Clamp so rounding that pushes |cos| just above 1 cannot produce NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = _one_hot_like(cosine, label, self.device_id)
        # Margin logit for the target class, plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output


class CosFace(nn.Module):
    r"""Implement of CosFace (https://arxiv.org/pdf/1801.09414.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                       if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta)-m
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m
        print("self.device_id", self.device_id)
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Return ``s``-scaled logits with margin ``m`` subtracted on the target class."""
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = _cosine_logits(input, self.weight, self.device_id)
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        # The helper only moves ``label`` to a GPU when device_id is set; the
        # original indexed device_id unconditionally and crashed on the CPU path.
        one_hot = _one_hot_like(cosine, label, self.device_id)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) + ')'


class SFaceLoss(nn.Module):
    r"""Classification head that also returns the SFace intra/inter-class loss:
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                       if device_id=None, it will be trained on CPU without model parallel.
        s: scale applied to the cosine logits
        k: slope of the sigmoid re-weighting curves
        a: angle offset of the intra-class weight curve
        b: angle offset of the inter-class weight curve
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, k=80.0, a=0.90, b=1.2):
        super(SFaceLoss, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.k = k
        self.a = a
        self.b = b
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        xavier_normal_(self.weight, gain=2, mode='out')

    def forward(self, input, label):
        """Return ``(output, loss, intra_mean, inter_mean, Wyi_s_mean, Wj_s_mean)``.

        ``output`` is ``s * cos(theta)``; ``loss`` is the sum of the mean
        intra-class and mean inter-class terms.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = _cosine_logits(input, self.weight, self.device_id)
        # --------------------------- s*cos(theta) ---------------------------
        output = cosine * self.s
        # --------------------------- sface loss ---------------------------
        one_hot = _one_hot_like(cosine, label, self.device_id)
        zero_hot = 1.0 - one_hot  # 1 everywhere except the target class

        WyiX = torch.sum(one_hot * output, 1)
        # The re-weighting factors are treated as constants (no gradient).
        with torch.no_grad():
            theta_yi = torch.acos(WyiX / self.s)
            weight_yi = 1.0 / (1.0 + torch.exp(-self.k * (theta_yi - self.a)))
        intra_loss = - weight_yi * WyiX

        Wj = zero_hot * output
        with torch.no_grad():
            theta_j = torch.acos(Wj / self.s)
            weight_j = 1.0 / (1.0 + torch.exp(self.k * (theta_j - self.b)))
        inter_loss = torch.sum(weight_j * Wj, 1)

        loss = intra_loss.mean() + inter_loss.mean()
        Wyi_s = WyiX / self.s
        Wj_s = Wj / self.s
        return output, loss, intra_loss.mean(), inter_loss.mean(), Wyi_s.mean(), Wj_s.mean()


class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Dropout -> Linear -> Dropout)."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` input.

    Scores are scaled by ``dim ** -0.5`` (the model width, not the per-head
    width); this is left unchanged so existing checkpoints behave the same.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        """Return attended tokens with the same shape as ``x``.

        ``mask``, when given, is flattened per sample and prefixed with one
        True entry (for the leading token); pairs where either token is
        masked out are excluded from the softmax.
        """
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # (b, 1, n, 1) * (b, 1, 1, n) -> (b, 1, n, n): the explicit head
            # axis lets the pairwise mask broadcast over heads instead of
            # being mis-aligned with the batch axis.
            mask = mask[:, None, :, None] * mask[:, None, None, :]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm residual (attention, feed-forward) pairs."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x


class ViTs_face(nn.Module):
    r"""Vision transformer face encoder with overlapping ("soft split") patches:
    Args:
        loss_type: 'None', 'Softmax', 'CosFace', 'ArcFace' or 'SFace'
        GPU_ID: forwarded to the loss head as ``device_id``
        num_class: number of classes of the loss head
        image_size: input side length, must be divisible by ``patch_size``
        patch_size: stride between neighbouring patches
        ac_patch_size: actual side length of each extracted patch
        pad: padding applied on every side before patch extraction
        dim: token / embedding width
        depth: number of transformer layers
        heads: number of attention heads
        mlp_dim: hidden width of the feed-forward blocks
        pool: 'cls' (use the class token) or 'mean' (average all tokens)
        channels: number of input image channels
        dim_head: width of each attention head
        dropout: dropout inside the transformer
        emb_dropout: dropout applied to the embedded tokens
    """

    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, ac_patch_size,
                 pad, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64,
                 dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * ac_patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
        # Overlapping patches: kernel ac_patch_size, stride patch_size.
        self.soft_split = nn.Unfold(kernel_size=(ac_patch_size, ac_patch_size),
                                    stride=(self.patch_size, self.patch_size),
                                    padding=(pad, pad))

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            if self.loss_type == 'Softmax':
                self.loss = Softmax(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'CosFace':
                self.loss = CosFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'ArcFace':
                self.loss = ArcFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'SFace':
                self.loss = SFaceLoss(in_features=dim, out_features=num_class, device_id=self.GPU_ID)

    def forward(self, img, label=None, mask=None):
        """Return the embedding, or ``(loss_head_output, embedding)`` when ``label`` is given."""
        # Unfold yields (batch, patch_dim, n_patches); put patches on axis 1.
        x = self.soft_split(img).transpose(1, 2)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        emb = self.mlp_head(x)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        else:
            return emb


def l2_norm(input, axis=1):
    """Divide ``input`` by its L2 norm taken along ``axis`` (kept for broadcasting)."""
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class IDVitLoss(nn.Module):
    """Identity loss: scaled MSE between ViT face embeddings of two image batches.

    Construction builds a ``ViTs_face`` network, loads its weights from a
    fixed checkpoint path, and moves it to CUDA in eval mode.
    """

    def __init__(self):
        super(IDVitLoss, self).__init__()
        print("Loading Vit ArcFace")
        DEVICE = torch.device("cuda:0")
        NUM_CLASS = 93431
        self.facenet = ViTs_face(
            loss_type='CosFace',
            GPU_ID=DEVICE,
            num_class=NUM_CLASS,
            image_size=112,
            patch_size=8,
            ac_patch_size=12,
            pad=4,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        )

        self.facenet.load_state_dict(torch.load("pretrained_models/Backbone_VITs_Epoch_2_Batch_12000_Time_2021-03-17-04-05_checkpoint.pth"))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet = self.facenet.cuda().eval()

    def extract_feats(self, x):
        """Resize ``x`` to 112x112 by average pooling and return its face embedding."""
        x = self.face_pool(x)
        x_feats = self.facenet(x.cuda())
        return x_feats

    def forward(self, y_hat, y):
        """Return ``10000 * mean((feats(y_hat) - feats(y)) ** 2)``.

        The target features are detached, so gradients only flow through
        ``y_hat``.
        """
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = torch.mean((y_hat_feats - y_feats) ** 2)

        return loss * 10000