import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from einops import rearrange, repeat

from configs.paths import DefaultPaths
from models.psp.encoders.model_irse import Backbone

# from IPython import embed

MIN_NUM_PATCHES = 16

class Softmax(nn.Module):
    r"""Implement of Softmax (normal classification head):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                   if device_id=None, it will be trained on CPU without model parallel.
    """

    def __init__(self, in_features, out_features, device_id):
        super(Softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input, label):
        if self.device_id is None:
            out = F.linear(input, self.weight, self.bias)
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            bias = sub_biases[0].cuda(self.device_id[0])
            out = F.linear(temp_x, weight, bias)
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                bias = sub_biases[i].cuda(self.device_id[i])
                out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
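
# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of how the plain Softmax head could be exercised on CPU
# (device_id=None); the feature dim, class count and batch size below are
# arbitrary placeholders.
def _demo_softmax_head():
    head = Softmax(in_features=512, out_features=10, device_id=None)
    feats = torch.randn(4, 512)          # 4 embeddings of dim 512
    labels = torch.randint(0, 10, (4,))  # ground-truth class indices
    logits = head(feats, labels)         # plain linear logits; label is unused on the CPU path
    return F.cross_entropy(logits, labels)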

class ArcFace(nn.Module):
    r"""Implement of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                   if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta+m)
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.50, easy_margin=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)

        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id is not None:
            one_hot = one_hot.cuda(self.device_id[0])
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + (
            (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output
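
# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of the ArcFace head on CPU: it returns margin-adjusted,
# s-scaled logits that are then fed to a standard cross-entropy loss.
# Dimensions and the class count are arbitrary placeholders.
def _demo_arcface_head():
    head = ArcFace(in_features=512, out_features=10, device_id=None, s=64.0, m=0.5)
    feats = torch.randn(4, 512)
    labels = torch.randint(0, 10, (4,))
    logits = head(feats, labels)  # cos(theta + m) applied on the target class only
    return F.cross_entropy(logits, labels)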

class CosFace(nn.Module):
    r"""Implement of CosFace (https://arxiv.org/pdf/1801.09414.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                   if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta)-m
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m
        print("self.device_id", self.device_id)
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)
        phi = cosine - self.m

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id is not None:
            # move one_hot and label to the first device only when model parallel is used
            one_hot = one_hot.cuda(self.device_id[0])
            label = label.cuda(self.device_id[0])
        # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + (
            (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) + ')'
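
# --- Illustrative usage (not part of the original module) ---
# Same calling convention as ArcFace, but the margin is subtracted from the
# cosine (cos(theta) - m) instead of being added to the angle. Sizes are
# arbitrary placeholders.
def _demo_cosface_head():
    head = CosFace(in_features=512, out_features=10, device_id=None, s=64.0, m=0.35)
    feats = torch.randn(4, 512)
    labels = torch.randint(0, 10, (4,))
    logits = head(feats, labels)
    return F.cross_entropy(logits, labels)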

class SFaceLoss(nn.Module):

    def __init__(self, in_features, out_features, device_id, s=64.0, k=80.0, a=0.90, b=1.2):
        super(SFaceLoss, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.k = k
        self.a = a
        self.b = b
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        # nn.init.xavier_uniform_(self.weight)
        # The original SFace code calls a custom xavier_normal_(weight, gain=2, mode='out'),
        # which is not defined in this file; fall back to the standard initializer.
        nn.init.xavier_normal_(self.weight, gain=2)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)

        # --------------------------- s*cos(theta) ---------------------------
        output = cosine * self.s

        # --------------------------- sface loss ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id is not None:
            one_hot = one_hot.cuda(self.device_id[0])
        one_hot.scatter_(1, label.view(-1, 1), 1)

        zero_hot = torch.ones(cosine.size())
        if self.device_id is not None:
            zero_hot = zero_hot.cuda(self.device_id[0])
        zero_hot.scatter_(1, label.view(-1, 1), 0)

        WyiX = torch.sum(one_hot * output, 1)
        with torch.no_grad():
            theta_yi = torch.acos(WyiX / self.s)
            weight_yi = 1.0 / (1.0 + torch.exp(-self.k * (theta_yi - self.a)))
        intra_loss = - weight_yi * WyiX

        Wj = zero_hot * output
        with torch.no_grad():
            # theta_j = torch.acos(Wj)
            theta_j = torch.acos(Wj / self.s)
            weight_j = 1.0 / (1.0 + torch.exp(self.k * (theta_j - self.b)))
        inter_loss = torch.sum(weight_j * Wj, 1)

        loss = intra_loss.mean() + inter_loss.mean()
        Wyi_s = WyiX / self.s
        Wj_s = Wj / self.s
        return output, loss, intra_loss.mean(), inter_loss.mean(), Wyi_s.mean(), Wj_s.mean()
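
# --- Illustrative usage (not part of the original module) ---
# SFaceLoss returns the scaled logits together with the combined loss and its
# intra-/inter-class components, so no external criterion is needed. Sizes are
# arbitrary placeholders.
def _demo_sface_head():
    head = SFaceLoss(in_features=512, out_features=10, device_id=None)
    feats = torch.randn(4, 512)
    labels = torch.randint(0, 10, (4,))
    output, loss, intra, inter, wyi_s, wj_s = head(feats, labels)
    return loss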

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max
        # embed()
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            # embed()
            x = ff(x)
        return x
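
# --- Illustrative usage (not part of the original module) ---
# The Transformer block maps a (batch, tokens, dim) sequence to a sequence of
# the same shape; the sizes below are arbitrary placeholders.
def _demo_transformer_block():
    blk = Transformer(dim=64, depth=2, heads=4, dim_head=16, mlp_dim=128, dropout=0.)
    tokens = torch.randn(2, 10, 64)
    return blk(tokens).shape  # torch.Size([2, 10, 64])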

class ViTs_face(nn.Module):
    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, ac_patch_size,
                 pad, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * ac_patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
        self.soft_split = nn.Unfold(kernel_size=(ac_patch_size, ac_patch_size), stride=(self.patch_size, self.patch_size), padding=(pad, pad))

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            if self.loss_type == 'Softmax':
                self.loss = Softmax(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'CosFace':
                self.loss = CosFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'ArcFace':
                self.loss = ArcFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)
            elif self.loss_type == 'SFace':
                self.loss = SFaceLoss(in_features=dim, out_features=num_class, device_id=self.GPU_ID)

    def forward(self, img, label=None, mask=None):
        p = self.patch_size
        x = self.soft_split(img).transpose(1, 2)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        emb = self.mlp_head(x)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        else:
            return emb
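
# --- Illustrative usage (not part of the original module) ---
# A small CPU-only ViTs_face with loss_type='None' that just returns
# embeddings. The 112x112 input with patch_size=8, ac_patch_size=12 and pad=4
# matches the overlapping-patch setup used by IDVitLoss below, while dim and
# depth are shrunk to keep the sketch cheap.
def _demo_vit_face_embedding():
    model = ViTs_face(
        loss_type='None', GPU_ID=None, num_class=10,
        image_size=112, patch_size=8, ac_patch_size=12, pad=4,
        dim=64, depth=1, heads=2, mlp_dim=128, dim_head=32,
    )
    imgs = torch.randn(2, 3, 112, 112)
    emb = model(imgs)  # (2, 64) embeddings, no loss head applied
    return emb.shape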

def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output

class IDVitLoss(nn.Module):
    def __init__(self):
        super(IDVitLoss, self).__init__()
        print("Loading Vit ArcFace")
        DEVICE = torch.device("cuda:0")
        NUM_CLASS = 93431
        self.facenet = ViTs_face(
            loss_type='CosFace',
            GPU_ID=DEVICE,
            num_class=NUM_CLASS,
            image_size=112,
            patch_size=8,
            ac_patch_size=12,
            pad=4,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        )
        self.facenet.load_state_dict(torch.load("pretrained_models/Backbone_VITs_Epoch_2_Batch_12000_Time_2021-03-17-04-05_checkpoint.pth"))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet = self.facenet.cuda().eval()

    def extract_feats(self, x):
        # x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x.cuda())
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = torch.mean((y_hat_feats - y_feats) ** 2)
        return loss * 10000
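
# --- Illustrative usage (not part of the original module) ---
# IDVitLoss needs a CUDA device and the pretrained ViT checkpoint under
# pretrained_models/, so this sketch only shows the calling convention and
# bails out when CUDA is unavailable. The 256x256 image size is an arbitrary
# placeholder; inputs are pooled to 112x112 internally.
def _demo_id_vit_loss():
    if not torch.cuda.is_available():
        return None
    criterion = IDVitLoss()
    y = torch.randn(2, 3, 256, 256).cuda()      # target images
    y_hat = torch.randn(2, 3, 256, 256).cuda()  # generated / reconstructed images
    return criterion(y_hat, y)                  # scalar identity-similarity loss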