import logging
import os

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter


class PartialFC(Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
| | """ |
| | rank: int |
| | Unique process(GPU) ID from 0 to world_size - 1. |
| | local_rank: int |
| | Unique process(GPU) ID within the server from 0 to 7. |
| | world_size: int |
| | Number of GPU. |
| | batch_size: int |
| | Batch size on current rank(GPU). |
| | resume: bool |
| | Select whether to restore the weight of softmax. |
| | margin_softmax: callable |
| | A function of margin softmax, eg: cosface, arcface. |
| | num_classes: int |
| | The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, |
| | required. |
| | sample_rate: float |
| | The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling |
| | can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. |
| | embedding_size: int |
| | The feature dimension, default is 512. |
| | prefix: str |
| | Path for save checkpoint, default is './'. |
| | """ |
        super(PartialFC, self).__init__()
        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
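
        # Class centers are sharded (almost) evenly across ranks: this rank stores `num_local`
        # centers covering the global class IDs [class_start, class_start + num_local).
        # For example (hypothetical numbers), with num_classes=10 and world_size=4 the four
        # ranks hold 3, 3, 2 and 2 centers starting at global classes 0, 3, 6 and 8.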
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))

        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
                    raise IndexError
                logging.info("softmax weight resume successfully!")
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight init!")
                logging.info("softmax weight mom init!")
        else:
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)

        self.index = None
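        # When sample_rate == 1 every center is used, so the full weight is the parameter and
        # update() becomes a no-op; otherwise sub_weight is re-created from the sampled centers.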
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    def save_params(self):
        """ Save the softmax weight and its momentum for this rank under `prefix`. """
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        """
        Sample all positive class centers on this rank, and randomly select negative class centers
        to fill up to a fixed `num_sample`.

        total_label: tensor
            Labels after all_gather across all GPUs.
        """
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
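        # Keep every positive class center and fill the remaining `num_sample` slots with randomly
        # chosen negatives: positives are forced to the top of the random scores (perm[positive] = 2.0)
        # so that topk always selects them first.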
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        """ Partial fc forward: `logits = X @ sample(W)^T`. """
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        """ Write the updated sub_weight and sub_weight_mom back into the full weight memory bank. """
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        """
        Get the sampled class centers for computing the softmax.

        label: tensor
            Label tensor on each rank.
        optimizer: opt
            Optimizer for partial fc, which needs access to the weight momentum buffer.
        """
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
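            # Point the optimizer's last param group at the current sub_weight and carry its
            # momentum buffer over, so momentum follows the (possibly re-sampled) class centers.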
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight

    def forward_backward(self, label, features, optimizer):
        """
        Partial fc forward and backward with model parallelism.

        label: tensor
            Label tensor on each rank (GPU).
        features: tensor
            Feature tensor on each rank (GPU).
        optimizer: optimizer
            Optimizer for partial fc.

        Returns:
        --------
        x_grad: tensor
            The gradient of the features.
        loss_v: tensor
            Cross-entropy loss value.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
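        # Gather the features from every rank so this rank can score the full global batch
        # against its local (sampled) class centers.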
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)
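
        # The softmax normalizer must be computed across ranks, since each rank only holds the
        # logits for its own shard of class centers; start from the global row-wise max for stability.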
        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)
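
            # exp(logits - global max), then all-reduce the per-row sums to get the full denominator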
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
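
            # normalize: logits_exp now holds the softmax probabilities for the local class shard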
            logits_exp.div_(logits_sum_exp)
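
            # one-hot targets for the samples whose positive class lives on this rank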
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)
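
            # cross-entropy loss: take the probability of the true class; the all-reduce sum is valid
            # because exactly one rank holds each sample's positive class and the others contribute zero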
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
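
            # gradient of softmax cross-entropy w.r.t. the logits: prob - one_hot, averaged over the global batch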
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)
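
        # backpropagate the hand-computed logit gradient through margin_softmax and the linear
        # layer to obtain the gradient w.r.t. total_features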
        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
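        # sum the feature-gradient contributions from all ranks and keep only this rank's slice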
        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size

        return x_grad, loss_v
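

# Illustrative usage sketch (not part of the original file). The names `backbone`, `ArcFace`,
# `opt_backbone`, `train_loader` and the rank/size variables below are assumptions for the
# example only; the real training loop may differ:
#
#     module_fc = PartialFC(rank, local_rank, world_size, batch_size, resume=False,
#                           margin_softmax=ArcFace(), num_classes=num_classes,
#                           sample_rate=0.1, embedding_size=512, prefix="./")
#     opt_fc = torch.optim.SGD([{"params": [module_fc.sub_weight]}], lr=0.1, momentum=0.9)
#     for img, label in train_loader:
#         features = normalize(backbone(img))
#         x_grad, loss_v = module_fc.forward_backward(label, features, opt_fc)
#         features.backward(x_grad)      # push the feature gradient into the backbone
#         opt_backbone.step()
#         opt_fc.step()
#         module_fc.update()             # write the sampled centers back to the memory bank
#         opt_backbone.zero_grad()
#         opt_fc.zero_grad()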