Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import Parameter | |
| import math | |
class ArcNet(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    Features and class weight columns are L2-normalised, so their product is
    ``cos(theta)`` between each feature and each class. For the target class
    only, ``cos(theta)`` is replaced by ``cos(theta + margin)``. Every logit is
    then multiplied by ``scale``.

    Args:
        feature_dim: Size of each input feature vector.
        class_dim: Number of classes (columns of ``weight``).
        margin: Additive angular margin ``m`` in radians.
        scale: Multiplier ``s`` applied to all output logits.
        easy_margin: If True, the margin is applied only where
            ``cos(theta) > 0``. If False, when ``theta + m`` would exceed
            ``pi`` (``cos(theta) <= cos(pi - m)``), the target logit falls
            back to ``cos(theta) - m * sin(m)``.
    """

    def __init__(self,
                 feature_dim,
                 class_dim,
                 margin=0.2,
                 scale=30.0,
                 easy_margin=False):
        super().__init__()
        self.feature_dim = feature_dim
        self.class_dim = class_dim
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        # One weight column per class, shape (feature_dim, class_dim).
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor;
        # the values are set by xavier_uniform_ right below.
        self.weight = Parameter(
            torch.empty(feature_dim, class_dim, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Return scaled margin logits.

        Args:
            input: Feature tensor of shape ``(batch, feature_dim)``.
            label: Integer class indices of shape ``(batch,)`` or
                ``(batch, 1)``. Any integer dtype is accepted.

        Returns:
            Tensor of shape ``(batch, class_dim)``: ``scale * cos(theta + m)``
            at each row's label column and ``scale * cos(theta)`` elsewhere.
        """
        # F.normalize clamps the norm at a small eps, so an all-zero row
        # becomes zeros instead of the NaN produced by a plain x / ||x||.
        input = F.normalize(input, p=2.0, dim=1)
        weight = F.normalize(self.weight, p=2.0, dim=0)
        cos = torch.matmul(input, weight)
        # The clamp guards against tiny negative values from rounding when
        # |cos| ~ 1. The 1e-6 keeps the sqrt (and its gradient) finite there.
        sin = torch.sqrt((1.0 - torch.square(cos)).clamp_min(0.0) + 1e-6)
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cos * cos_m - sin * sin_m
        # th = cos(pi - m) = -cos(m); mm = sin(pi - m) * m = sin(m) * m
        th = -cos_m
        mm = sin_m * self.margin
        if self.easy_margin:
            phi = self._paddle_where_more_than(cos, 0, phi, cos)
        else:
            phi = self._paddle_where_more_than(cos, th, phi, cos - mm)
        # Accept (batch,) or (batch, 1) labels. F.one_hot requires int64.
        label = label.reshape(-1).long()
        target = F.one_hot(label, self.class_dim).bool()
        # torch.where (not mask arithmetic) so non-finite values in the
        # unselected branch cannot leak into the result.
        output = torch.where(target, phi, cos)
        return output * self.scale

    def _paddle_where_more_than(self, target, limit, x, y):
        """Select ``x`` where ``target > limit`` and ``y`` elsewhere.

        The name is kept from the original Paddle port for compatibility.
        It is implemented with torch.where rather than
        ``mask * x + (1 - mask) * y``, because the latter turns an inf in the
        unselected branch into NaN (0 * inf).
        """
        return torch.where(target > limit, x, y)