File size: 2,356 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from torch import nn
from torch.nn import functional as F
from .sa import SALayer
from libs.utils.metric import cal_cls_acc


def align_segments_feat(segments_feat):
    """Right-pad a batch of per-sample segment features to a common width.

    Args:
        segments_feat: list of tensors, one per sample, each shaped
            (feat_dim, num_segments_i) — num_segments_i may differ per sample.

    Returns:
        A tuple ``(aligned, masks)`` where ``aligned`` is a stacked tensor of
        shape (batch, feat_dim, max_segments) with zero padding on the right,
        and ``masks`` is (batch, max_segments) with 1 marking valid positions.
        The mask shares dtype/device with the features so it can be used
        directly in arithmetic downstream.
    """
    reference = segments_feat[0]
    widths = [feat.shape[1] for feat in segments_feat]
    target_width = max(widths)

    masks = torch.zeros(
        (len(segments_feat), target_width),
        dtype=reference.dtype,
        device=reference.device,
    )

    padded = []
    for idx, (feat, width) in enumerate(zip(segments_feat, widths)):
        masks[idx, :width] = 1
        # Pad only the last (segment) dimension up to the common width.
        padded.append(
            F.pad(feat, (0, target_width - width, 0, 0), mode='constant', value=0)
        )

    return torch.stack(padded, dim=0), masks


class HeadBodyDividePredictor(nn.Module):
    """Predicts a single divide index among candidate segment boundaries.

    Candidate boundary features are gathered per sample, self-attended via
    ``SALayer``, and scored with a 1x1 Conv1d; the argmax over valid
    (unmasked) positions is the predicted divide point.
    """

    def __init__(self, in_dim, head_nums, scale=1):
        """
        Args:
            in_dim: channel dimension of the input feature maps.
            head_nums: number of attention heads for the fusion layer.
            scale: factor applied to segment indices before gathering
                (maps segment coordinates onto the feature time axis).
        """
        super().__init__()
        self.in_dim = in_dim
        self.scale = scale
        self.fusion_layer = SALayer(in_dim, in_dim, head_nums)
        # 1x1 conv acts as a per-position linear scorer: (B, C, S) -> (B, 1, S).
        self.classifier = nn.Conv1d(in_dim, 1, 1, 1, 0)

    def forward(self, feats, segments, divide_labels=None):
        """
        Args:
            feats: per-sample feature tensors, each (in_dim, T_i).
            segments: per-sample lists of candidate boundary positions.
            divide_labels: ground-truth divide indices; required only in
                training mode (used for loss/accuracy).

        Returns:
            (divide_preds, result_info, ext_info) — predictions as a Python
            list, a dict with 'divide_loss'/'divide_acc' when training, and
            an (empty) extras dict.
        """
        scaled = [[int(pos * self.scale) for pos in seg] for seg in segments]
        # Gather the feature column at each candidate boundary: (in_dim, S_i).
        gathered = [feat[:, idx] for feat, idx in zip(feats, scaled)]
        padded, masks = align_segments_feat(gathered)

        fused = self.fusion_layer(padded, masks)
        logits = self.classifier(fused).squeeze(1)
        # Push padded positions to a huge negative value so argmax / softmax
        # (inside cross_entropy) ignore them.
        logits = logits - (1 - masks) * 1e8
        preds = torch.argmax(logits, dim=1)

        result_info = dict()
        ext_info = dict()
        if self.training:
            result_info['divide_loss'] = F.cross_entropy(logits, divide_labels)
            correct_nums, total_nums = cal_cls_acc(preds, divide_labels)
            if total_nums != 0:
                result_info['divide_acc'] = correct_nums / total_nums

        return preds.detach().cpu().tolist(), result_info, ext_info