| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .sa import SALayer | |
| from libs.utils.metric import cal_cls_acc | |
def align_segments_feat(segments_feat):
    """Pad a list of per-sample feature tensors to a common length and stack them.

    Each tensor's segment count is read from ``shape[1]`` and zero-padding is
    appended at the end of the last dimension, so the tensors are presumably
    2-D ``(channels, num_segments)`` -- TODO confirm against the caller.

    Args:
        segments_feat: Non-empty sequence of tensors. The dtype and device of
            the first tensor are used for the returned mask.

    Returns:
        A tuple ``(aligned, masks)``:
            * ``aligned`` -- the padded tensors stacked along a new leading
              batch dimension.
            * ``masks`` -- tensor of shape ``(batch, max_segments)`` holding 1
              at real segment positions and 0 at padded positions.
    """
    reference = segments_feat[0]
    lengths = [feat.shape[1] for feat in segments_feat]
    longest = max(lengths)

    masks = torch.zeros(
        [len(segments_feat), longest],
        dtype=reference.dtype,
        device=reference.device,
    )

    padded = []
    for row, (feat, length) in enumerate(zip(segments_feat, lengths)):
        # Mark the real (non-padded) positions of this sample.
        masks[row, :length] = 1
        # Right-pad the last dimension with zeros up to the longest sample.
        padded.append(
            F.pad(feat, (0, longest - length, 0, 0), mode='constant', value=0)
        )

    return torch.stack(padded, dim=0), masks
class HeadBodyDividePredictor(nn.Module):
    """Pick, for each sample, one segment index out of a list of candidates.

    Features are sampled at the candidate segment positions, fused by an
    ``SALayer`` and scored by a 1x1 convolution that emits one logit per
    segment; the prediction is the argmax over segments. Judging by the class
    name, the chosen index is the head/body dividing position.
    """

    def __init__(self, in_dim, head_nums, scale=1):
        """Build the fusion layer and the per-segment classifier.

        Args:
            in_dim: Channel count of the input features; also used as the
                output size passed to ``SALayer``.
            head_nums: Forwarded to ``SALayer`` (presumably the number of
                attention heads -- confirm in ``SALayer``).
            scale: Factor applied to every segment position before indexing
                into the features.
        """
        super().__init__()
        self.in_dim = in_dim
        self.scale = scale
        self.fusion_layer = SALayer(in_dim, in_dim, head_nums)
        # 1x1 convolution: maps in_dim channels to a single logit per segment.
        self.classifier = nn.Conv1d(
            in_dim, 1, kernel_size=1, stride=1, padding=0
        )

    def forward(self, feats, segments, divide_labels=None):
        """Score candidate segments and return the predicted index per sample.

        Args:
            feats: Per-sample feature tensors, each indexed as
                ``feat[:, positions]`` (presumably ``(channels, length)`` --
                TODO confirm).
            segments: Per-sample sequences of segment positions; each value is
                multiplied by ``self.scale`` and truncated with ``int``.
            divide_labels: Target indices, passed to ``F.cross_entropy`` and
                ``cal_cls_acc`` when the module is in training mode.

        Returns:
            A tuple ``(divide_preds, result_info, ext_info)``:
                * ``divide_preds`` -- Python list with the argmax segment
                  index of each sample.
                * ``result_info`` -- dict; in training mode it holds
                  ``'divide_loss'`` and, when ``cal_cls_acc`` reports a
                  non-zero total, ``'divide_acc'``.
                * ``ext_info`` -- dict, always empty here.
        """
        # Map segment positions into feature coordinates.
        scaled_segments = [
            [int(position * self.scale) for position in sample_segments]
            for sample_segments in segments
        ]
        # Gather the feature columns at each sample's segment positions.
        gathered = [
            sample_feats[:, sample_positions]
            for sample_feats, sample_positions in zip(feats, scaled_segments)
        ]

        fused, masks = align_segments_feat(gathered)
        fused = self.fusion_layer(fused, masks)

        logits = self.classifier(fused).squeeze(1)
        # Subtract a large constant at padded positions so they never win the
        # argmax and contribute negligibly to the softmax.
        padding_penalty = (1 - masks) * 1e8
        logits = logits - padding_penalty
        divide_preds = torch.argmax(logits, dim=1)

        result_info = {}
        ext_info = {}
        if self.training:
            result_info['divide_loss'] = F.cross_entropy(logits, divide_labels)
            correct_nums, total_nums = cal_cls_acc(divide_preds, divide_labels)
            if total_nums != 0:
                result_info['divide_acc'] = correct_nums / total_nums

        return divide_preds.detach().cpu().tolist(), result_info, ext_info