import torch
from torch import nn
import torch.nn.functional as F


class BoundaryLoss(nn.Module):
    """Adaptive decision-boundary loss.

    From "Deep Open Intent Classification with Adaptive Decision Boundary"
    (https://arxiv.org/pdf/2012.10209.pdf). Each known class owns one learnable
    radius. The loss penalises samples whose distance to their class centroid
    lies outside the radius (``euc_dis - d``) and samples that lie inside it
    (``d - euc_dis``).

    Args:
        num_labels: Number of known classes; one radius parameter per class.
        feat_dim: Feature dimensionality. Stored on the module but not used by
            the loss computation itself.
        device: Device on which the radius parameter is allocated
            (``None`` means the default device).
    """

    def __init__(self, num_labels=10, feat_dim=2, device=None):
        super().__init__()
        self.num_labels = num_labels
        self.feat_dim = feat_dim
        # Raw, unconstrained radii drawn from N(0, 1). forward() passes them
        # through softplus so the effective radius is always positive.
        self.delta = nn.Parameter(torch.empty(num_labels, device=device))
        nn.init.normal_(self.delta)

    def forward(self, pooled_output, centroids, labels):
        """Compute the boundary loss for a batch.

        Args:
            pooled_output: Sample features; presumably shaped
                (batch, feat_dim), since the distance is reduced over dim 1.
            centroids: Per-class centroids, indexed by ``labels``; presumably
                shaped (num_labels, feat_dim).
            labels: Integer class index for each sample in the batch.

        Returns:
            A tuple ``(loss, delta)``:
            - ``loss`` is the mean outside-boundary penalty plus the mean
              inside-boundary penalty.
            - ``delta`` holds the softplus-transformed radii for all classes.
        """
        delta = F.softplus(self.delta)
        c = centroids[labels]
        d = delta[labels]
        x = pooled_output

        # L2 distance of each sample to its own class centroid.
        euc_dis = torch.norm(x - c, 2, 1).view(-1)

        # Build the masks in the distances' own dtype. The cast keeps them on
        # euc_dis's device; hard-coding torch.cuda.FloatTensor breaks on CPU.
        pos_mask = (euc_dis > d).to(euc_dis.dtype)
        neg_mask = (euc_dis < d).to(euc_dis.dtype)

        # Outside the boundary: pull the sample in (gradient shrinks the
        # distance / grows the radius).
        pos_loss = (euc_dis - d) * pos_mask
        # Inside the boundary: penalise the slack, which shrinks the radius.
        neg_loss = (d - euc_dis) * neg_mask
        loss = pos_loss.mean() + neg_loss.mean()
        return loss, delta