training_sem / libs /model /divide_predictor.py
kai-2054's picture
Initial commit: add code
cb0ad2d
import torch
from torch import nn
from torch.nn import functional as F
from .sa import SALayer
from libs.utils.metric import cal_cls_acc
def align_segments_feat(segments_feat):
dtype = segments_feat[0].dtype
device = segments_feat[0].device
batch_size = len(segments_feat)
max_segment_nums = max([item.shape[1] for item in segments_feat])
aligned_segments_feat = list()
masks = torch.zeros([batch_size, max_segment_nums], dtype=dtype, device=device)
for batch_idx in range(batch_size):
cur_segment_nums = segments_feat[batch_idx].shape[1]
masks[batch_idx, :cur_segment_nums] = 1
aligned_segments_feat.append(
F.pad(
segments_feat[batch_idx],
(0, max_segment_nums - cur_segment_nums, 0, 0),
mode='constant',
value=0
)
)
aligned_segments_feat = torch.stack(aligned_segments_feat, dim=0)
return aligned_segments_feat, masks
class HeadBodyDividePredictor(nn.Module):
def __init__(self, in_dim, head_nums, scale=1):
super().__init__()
self.in_dim = in_dim
self.scale = scale
self.fusion_layer = SALayer(in_dim, in_dim, head_nums)
self.classifier= nn.Conv1d(in_dim, 1, 1, 1, 0)
def forward(self, feats, segments, divide_labels=None):
segments = [[int(subitem * self.scale) for subitem in item] for item in segments]
segments_feat = [feats_pi[:, segments_pi] for feats_pi, segments_pi in zip(feats, segments)]
aligned_segments_feat, masks = align_segments_feat(segments_feat)
aligned_segments_feat = self.fusion_layer(aligned_segments_feat, masks)
divide_logits = self.classifier(aligned_segments_feat).squeeze(1)
divide_logits = divide_logits - (1 - masks) * 1e8
divide_preds = torch.argmax(divide_logits, dim=1)
result_info = dict()
ext_info = dict()
if self.training:
result_info['divide_loss'] = F.cross_entropy(divide_logits, divide_labels)
correct_nums, total_nums = cal_cls_acc(divide_preds, divide_labels)
if total_nums != 0:
result_info['divide_acc'] = correct_nums / total_nums
divide_preds = divide_preds.detach().cpu().tolist()
return divide_preds, result_info, ext_info