| from __future__ import print_function
|
| from __future__ import division
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import Parameter
|
| import torch.distributed as dist
|
| import math
|
|
|
|
|
def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit L2 norm along ``axis``.

    :param input: tensor to normalise
    :param axis: dimension along which the L2 norm is taken
    :param eps: lower bound applied to the norm, so an all-zero slice maps to
        zeros instead of NaN (0 / 0); has no effect on slices with norm >= eps
    :return: tensor of the same shape as ``input``
    """
    norm = torch.norm(input, p=2, dim=axis, keepdim=True).clamp(min=eps)
    return torch.div(input, norm)
|
|
|
|
|
def calc_logits(embeddings, kernel):
    """Compute cosine-similarity logits between embeddings and class weights.

    :param embeddings: 2-D features; each row is L2-normalised
    :param kernel: 2-D class weights; each column is L2-normalised
    :return: ``(cos_theta, origin_cos)``: the cosine matrix clamped to
        [-1, 1] (still attached to autograd) and a detached copy of it
    """
    unit_embeddings = l2_norm(embeddings, axis=1)
    unit_kernel = l2_norm(kernel, axis=0)
    cos_theta = torch.mm(unit_embeddings, unit_kernel).clamp(-1, 1)
    # Independent copy without gradient history, used for margin-free metrics.
    origin_cos = cos_theta.detach().clone()
    return cos_theta, origin_cos
|
|
|
|
|
@torch.no_grad()
def all_gather_tensor(input_tensor):
    """All-gather tensors whose dim-0 sizes differ between workers.

    Every rank shares its dim-0 length, pads its tensor to the largest length,
    all-gathers the padded tensors and finally strips each rank's padding.

    :param input_tensor: local tensor; all dimensions except 0 must match
        across ranks
    :return: concatenation along dim 0, in rank order, of every rank's tensor
        (allocated on the current CUDA device)
    """
    world_size = dist.get_world_size()

    local_size = torch.tensor([input_tensor.shape[0]], dtype=torch.int64, device='cuda')
    size_list = [torch.empty_like(local_size) for _ in range(world_size)]
    dist.all_gather(tensor_list=size_list, tensor=local_size, async_op=False)
    # A single device->host sync for all sizes instead of one .item() per rank.
    sizes = torch.cat(size_list, dim=0).tolist()
    max_size = max(sizes)

    # Allocate directly on the GPU; padding rows are never read back.
    padded = torch.empty(max_size, *input_tensor.shape[1:],
                         dtype=input_tensor.dtype, device='cuda')
    padded[:input_tensor.shape[0]] = input_tensor
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(tensor_list=gathered, tensor=padded, async_op=False)

    return torch.cat([chunk[:n] for chunk, n in zip(gathered, sizes)], dim=0)
|
|
|
|
|
def calc_top1_acc(original_logits, label, ddp=False):
    """
    Compute the top-1 accuracy during training.

    :param original_logits: logits w/o margin, [bs, C]
    :param label: labels [bs]
    :param ddp: if True, sum the correct predictions AND the sample counts over
        all ranks with ``dist.all_reduce``, so ranks may hold different batch
        sizes
    :return: fraction of correct predictions as a Python float (over all gpus
        when ``ddp``); 0.0 when there are no samples
    """
    assert original_logits.size(0) == label.size(0), \
        'logits and labels must have the same batch size'

    with torch.no_grad():
        _, max_index = torch.max(original_logits, dim=1, keepdim=False)
        correct = (max_index == label).sum()
        if ddp:
            # Reduce numerator and denominator together so unequal per-rank
            # batch sizes are weighted correctly.
            stats = torch.stack([correct, torch.full_like(correct, label.size(0))])
            dist.all_reduce(stats, dist.ReduceOp.SUM)
            num_correct, total = stats.tolist()
        else:
            num_correct, total = correct.item(), label.size(0)

    return num_correct / total if total > 0 else 0.0
|
|
|
def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit L2 norm along ``axis``.

    This later definition shadows the earlier ``l2_norm`` in this module, so it
    is the one callers such as ``calc_logits`` resolve at call time.

    :param input: tensor to normalise
    :param axis: dimension along which the L2 norm is taken
    :param eps: lower bound applied to the norm, so an all-zero slice maps to
        zeros instead of NaN (0 / 0); has no effect on slices with norm >= eps
    :return: tensor of the same shape as ``input``
    """
    norm = torch.norm(input, p=2, dim=axis, keepdim=True).clamp(min=eps)
    return torch.div(input, norm)
|
|
|
|
|
class FC_ddp2(nn.Module):
    """
    Implement of (CVPR2021 Consistent Instance False Positive Improves Fairness in Face Recognition)

    No model parallel is used: every rank holds the full
    ``[in_features, out_features]`` kernel. With ``ddp=True`` the CIFP
    threshold and the reported accuracy are aggregated over all ranks through
    ``torch.distributed`` (an initialised process group is required, and the
    all-gather allocates CUDA buffers). With ``ddp=False`` both are computed
    from the local batch only and no collective is issued.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 scale=64.0,
                 margin=0.4,
                 mode='cosface',
                 use_cifp=False,
                 reduction='mean',
                 ddp=False):
        """ Args:
            in_features: size of each input feature vector
            out_features: number of classes (columns of the kernel)
            scale: factor applied to the cosine logits before the loss
            margin: additive cosine margin ('cosface') or angular margin in
                radians ('arcface')
            mode: 'cosface' or 'arcface'
            use_cifp: subtract the CIFP false-positive penalty from the target logit
            reduction: reduction mode forwarded to ``nn.CrossEntropyLoss``
            ddp: aggregate CIFP statistics and accuracy across ranks
        """
        super(FC_ddp2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.mode = mode
        self.use_cifp = use_cifp
        self.kernel = Parameter(torch.Tensor(in_features, out_features))
        self.ddp = ddp
        nn.init.normal_(self.kernel, std=0.01)

        self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction)

    def apply_margin(self, target_cos_theta):
        """Return ``target_cos_theta`` with the configured margin applied.

        'cosface': cos(theta) - m.
        'arcface': cos(theta + m) while cos(theta) > cos(pi - m), otherwise
        cos(theta) - sin(pi - m) * m.
        """
        assert self.mode in ['cosface', 'arcface'], 'Please check the mode'
        if self.mode == 'arcface':
            cos_m = math.cos(self.margin)
            sin_m = math.sin(self.margin)
            theta = math.cos(math.pi - self.margin)
            sinmm = math.sin(math.pi - self.margin) * self.margin
            sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2))
            # cos(a) * cos(m) - sin(a) * sin(m) == cos(a + m)
            cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m
            target_cos_theta_m = torch.where(
                target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm)
        elif self.mode == 'cosface':
            target_cos_theta_m = target_cos_theta - self.margin

        return target_cos_theta_m

    def _far_rank(self, topk_mask, sample_num):
        """Return how many top-scoring entries define the CIFP threshold.

        Computed as ceil(far * n) with far = 1 / (out_features - 1) and n the
        number of negative (sample, class) pairs that are not already flagged in
        ``topk_mask``. Counts are summed over ranks when ``ddp`` is set, so ranks
        may hold different batch sizes; no collective is issued otherwise.
        """
        far = 1 / (self.out_features - 1)
        counts = torch.stack([
            topk_mask.sum(),
            torch.tensor(sample_num, dtype=torch.int64, device=topk_mask.device),
        ])
        if self.ddp:
            dist.all_reduce(counts)
        topk_sum, total_samples = counts.tolist()
        negatives = total_samples * (self.out_features - 1) - topk_sum
        return int(math.ceil(far * negatives))

    def _cifp_target(self, embeddings, cos_theta, label, rows,
                     target_cos_theta, target_cos_theta_m):
        """Return ``target_cos_theta_m`` minus the CIFP false-positive penalty.

        :param embeddings: local features passed to ``forward``
        :param cos_theta: margin-free cosine logits [bs, C], with gradients
        :param label: class indices [bs]
        :param rows: ``arange(bs)`` on ``cos_theta``'s device
        :param target_cos_theta: ``cos_theta`` at the labels, [bs, 1]
        :param target_cos_theta_m: target cosines with the margin applied, [bs, 1]
        :return: penalised target logits, [bs, 1]
        """
        # Same cosines, computed with the kernel detached.
        cos_theta_, _ = calc_logits(embeddings, self.kernel.detach())

        mask = torch.zeros_like(cos_theta)
        mask.scatter_(1, label.view(-1, 1).long(), 1.0)
        # Shift target entries below every negative (cosines lie in [-1, 1]).
        tmp_cos_theta = cos_theta - 2 * mask
        tmp_cos_theta_ = cos_theta_ - 2 * mask
        target_cos_theta_ = cos_theta_[rows, label].view(-1, 1)

        # Negatives already scoring above their own sample's target cosine.
        topk_mask = torch.greater(tmp_cos_theta, target_cos_theta)
        far_rank = self._far_rank(topk_mask, cos_theta.size(0))

        # Threshold = far_rank-th largest remaining score (over all ranks when
        # ddp). k is clamped so topk never asks for more elements than exist.
        candidates = (tmp_cos_theta.detach() - 2 * topk_mask.to(torch.float32)).flatten()
        top_values = torch.topk(candidates, k=min(far_rank, candidates.numel()))[0]
        if self.ddp:
            top_values = all_gather_tensor(top_values.contiguous())
        k = max(1, min(far_rank, top_values.numel()))
        cos_theta_neg_th = torch.topk(top_values, k=k)[0][-1]

        cond = torch.bitwise_not(topk_mask) & torch.greater(tmp_cos_theta, cos_theta_neg_th)
        cond_f = cond.to(torch.float32)
        neg = cond_f * tmp_cos_theta
        neg_ = cond_f * tmp_cos_theta_
        # Keep the gradient-carrying value where the margin target exceeds it;
        # otherwise use the detached-kernel value.
        neg = torch.where(torch.greater(target_cos_theta_m, neg), neg, neg_)
        neg = torch.pow(neg, 2)
        # Mean of the selected squared negatives per row (divisor at least 1).
        times = torch.sum(torch.greater(neg, 0).to(torch.float32), dim=1, keepdim=True)
        times = torch.where(torch.greater(times, 0), times, torch.ones_like(times))
        penalty = torch.sum(neg, dim=1, keepdim=True) / times

        return target_cos_theta_m - (1 + target_cos_theta_) * penalty

    def forward(self, embeddings, label, return_logits=False):
        """
        :param embeddings: local features [bs, in_features]
        :param label: local class indices [bs]
        :param return_logits: bool, also return the scaled logits
        :return:
            loss: computed local loss, w/wo CIFP
            acc: top-1 accuracy of the margin-free logits (over all ranks when ddp)
            output: only when ``return_logits``; scaled logits with margins,
                with gradients, [bs, C]
        """
        cos_theta, origin_cos = calc_logits(embeddings, self.kernel)
        rows = torch.arange(cos_theta.size(0), device=cos_theta.device)
        target_cos_theta = cos_theta[rows, label].view(-1, 1)
        target_cos_theta_m = self.apply_margin(target_cos_theta)

        if self.use_cifp:
            target_cos_theta_m = self._cifp_target(
                embeddings, cos_theta, label, rows,
                target_cos_theta, target_cos_theta_m)

        cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)

        output = cos_theta * self.scale
        loss = self.criteria(output, label)
        acc = calc_top1_acc(origin_cos * self.scale, label, self.ddp)

        if return_logits:
            return loss, acc, output
        return loss, acc
|
|
|
|
|
class FC_ddp(nn.Module):
    """
    Margin-based softmax loss applied directly to per-class scores.

    Named after (CVPR2021 Consistent Instance False Positive Improves Fairness
    in Face Recognition), but this head owns no kernel and computes no CIFP
    term. The ``embeddings`` given to ``forward`` are used as per-class scores:
    they pass through a sigmoid, the configured margin is applied to the
    label column, and the result is scaled and fed to cross-entropy.
    ``in_features`` and ``use_cifp`` are stored but not used by this class.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 scale=8.0,
                 margin=0.2,
                 mode='cosface',
                 use_cifp=False,
                 reduction='mean'):
        """ Args:
            in_features: stored for interface compatibility (unused here)
            out_features: number of classes
            scale: factor applied to the margin-adjusted scores before the loss
            margin: additive margin ('cosface') or angular margin in radians
                ('arcface')
            mode: 'cosface' or 'arcface'
            use_cifp: stored for interface compatibility (unused here)
            reduction: reduction mode forwarded to ``nn.CrossEntropyLoss``
        """
        super(FC_ddp, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.mode = mode
        self.use_cifp = use_cifp

        self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction)
        self.sig = torch.nn.Sigmoid()

    def apply_margin(self, target_cos_theta):
        """Return ``target_cos_theta`` with the configured margin applied.

        'cosface': cos(theta) - m.
        'arcface': cos(theta + m) while cos(theta) > cos(pi - m), otherwise
        cos(theta) - sin(pi - m) * m.
        """
        assert self.mode in ['cosface', 'arcface'], 'Please check the mode'
        if self.mode == 'arcface':
            cos_m = math.cos(self.margin)
            sin_m = math.sin(self.margin)
            theta = math.cos(math.pi - self.margin)
            sinmm = math.sin(math.pi - self.margin) * self.margin
            sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2))
            # cos(a) * cos(m) - sin(a) * sin(m) == cos(a + m)
            cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m
            target_cos_theta_m = torch.where(
                target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm)
        elif self.mode == 'cosface':
            target_cos_theta_m = target_cos_theta - self.margin

        return target_cos_theta_m

    def forward(self, embeddings, label, return_logits=False):
        """
        :param embeddings: per-class scores [bs, C] (indexed by ``label`` on dim 1)
        :param label: local class indices [bs]
        :param return_logits: accepted for signature compatibility; currently ignored
        :return: loss computed by ``self.criteria`` on the scaled, margin-adjusted
            sigmoid scores
        """
        cos_theta = self.sig(embeddings)
        rows = torch.arange(cos_theta.size(0), device=cos_theta.device)
        target_cos_theta = cos_theta[rows, label].view(-1, 1)

        # Honour self.mode (previously the cosface margin was hard-coded here).
        target_cos_theta_m = self.apply_margin(target_cos_theta)

        # Write into a copy so the sigmoid output stays untouched.
        out = cos_theta.clone()
        out.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)
        out = out * self.scale

        return self.criteria(out, label)
|
|
|