| import random |
| from collections import defaultdict |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .abstract_loss_func import AbstractLossClass |
| from metrics.registry import LOSSFUNC |
|
|
|
|
def swap_spe_features(type_list, value_list):
    """Permute ``value_list`` so each position receives an entry of the same type.

    Positions are grouped by their label in ``type_list``; the positions in
    each group are shuffled with the ``random`` module, and every position is
    then assigned one (not necessarily different) position from its own group.
    Each source position is used exactly once, so the result is a permutation
    of ``value_list`` along its first dimension.

    Args:
        type_list: 1-D tensor of hashable labels (converted via
            ``.cpu().numpy().tolist()``), one label per entry of ``value_list``.
        value_list: Indexable container (e.g. a tensor) that supports
            list-based indexing along its first dimension.

    Returns:
        ``value_list`` indexed by the within-type permutation.
    """
    labels = type_list.cpu().numpy().tolist()

    # Bucket every position under its label, in order of first appearance.
    buckets = defaultdict(list)
    for position, label in enumerate(labels):
        buckets[label].append(position)

    # Randomise the order inside each bucket.
    for members in buckets.values():
        random.shuffle(members)

    # Draw one source position per slot from the bucket matching its label.
    permutation = [buckets[label].pop() for label in labels]

    return value_list[permutation]
|
|
|
|
@LOSSFUNC.register_module(module_name="contrastive_regularization")
class ContrastiveLoss(AbstractLossClass):
    """Margin-based contrastive regularization over common and specific features.

    The batch is split into two halves along dimension 0; following the
    variable naming, the first half is treated as "real" and the second half
    as "fake". For each half, randomly permuted samples from the same half
    act as positives and randomly permuted samples from the other half act as
    negatives.

    Args:
        margin: Margin added to ``d(anchor, positive) - d(anchor, negative)``
            before clamping at zero.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def contrastive_loss(self, anchor, positive, negative):
        """Return the mean hinge loss over anchor/positive/negative triplets.

        Computes ``mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0))``
        where ``d`` is ``torch.nn.functional.pairwise_distance``.
        """
        dist_pos = F.pairwise_distance(anchor, positive)
        dist_neg = F.pairwise_distance(anchor, negative)
        return torch.mean(
            torch.clamp(dist_pos - dist_neg + self.margin, min=0.0)
        )

    def forward(self, common, specific, spe_label):
        """Compute the summed contrastive loss for common and specific features.

        Args:
            common: Feature tensor whose first dimension is the batch;
                presumably shaped (batch, dim) -- confirm against the caller.
            specific: Feature tensor laid out like ``common``.
            spe_label: 1-D tensor of per-sample labels used to pick
                same-label partners for ``specific``.

        Returns:
            Scalar tensor: the sum of the four contrastive terms
            (real/fake x common/specific).

        Raises:
            ValueError: If the batch size is zero or odd. ``chunk(2)`` and
                the ``bs // 2`` index ranges only describe the same halves
                for an even batch; an odd batch would otherwise broadcast
                silently or fail with an opaque shape error.
        """
        bs = common.shape[0]
        if bs == 0 or bs % 2 != 0:
            raise ValueError(
                "ContrastiveLoss expects a non-empty batch with an even number "
                f"of samples (first half real, second half fake); got {bs}."
            )
        half = bs // 2

        real_common, fake_common = common.chunk(2)

        # Positives/negatives for the common features: random permutations of
        # each half. The shuffle order (first half, then second half) is kept
        # so the random stream is consumed exactly as before.
        idx_list = list(range(0, half))
        random.shuffle(idx_list)
        real_common_anchor = common[idx_list]

        idx_list = list(range(half, bs))
        random.shuffle(idx_list)
        fake_common_anchor = common[idx_list]

        # For the specific features, partners are drawn among samples that
        # share the same spe_label.
        specific_anchor = swap_spe_features(spe_label, specific)
        real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
        real_specific, fake_specific = specific.chunk(2)

        # Common features: same-half samples are positives, other-half negatives.
        loss_realcommon = self.contrastive_loss(
            real_common, real_common_anchor, fake_common_anchor
        )
        loss_fakecommon = self.contrastive_loss(
            fake_common, fake_common_anchor, real_common_anchor
        )

        # Specific features: same-label partners from the own half are
        # positives, partners from the other half are negatives.
        loss_realspecific = self.contrastive_loss(
            real_specific, real_specific_anchor, fake_specific_anchor
        )
        loss_fakespecific = self.contrastive_loss(
            fake_specific, fake_specific_anchor, real_specific_anchor
        )

        loss = loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
        return loss