from collections import defaultdict

import torch
import torch.nn.functional as F
def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx
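# Quick usage sketch (illustrative values, not from the original source):
#   >>> make_positions(torch.tensor([[3, 5, 0], [7, 0, 0]]), padding_idx=0)
#   tensor([[1, 2, 0],
#           [1, 0, 0]])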


def softmax(x, dim):
    return F.softmax(x, dim=dim, dtype=torch.float32)


def sequence_mask(lengths, maxlen, dtype=torch.bool):
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    mask = mask.type(dtype)  # .type() is not in-place; assign so `dtype` actually takes effect
    return mask
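# Usage sketch (illustrative values, not from the original source): True marks valid steps.
#   >>> sequence_mask(torch.tensor([3, 1]), maxlen=4)
#   tensor([[ True,  True,  True, False],
#           [ True, False, False, False]])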


def weights_nonzero_speech(target):
    # target: [B, T, mel_bins]; returns a 0/1 weight of the same shape that is 1
    # for frames with any non-zero mel value (i.e. non-padded frames).
    dim = target.size(-1)
    return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value
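# Usage sketch (hypothetical module/key names, not from the original source):
#   incremental_state = {}
#   set_incremental_state(attn_layer, incremental_state, 'prev_key', k)
#   k_cached = get_incremental_state(attn_layer, incremental_state, 'prev_key')
# Keys are namespaced by module class name and instance id, so two layers of the
# same type never clobber each other's cached state.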


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with a large negative value (-1e8),
    a finite stand-in for -inf."""
    return t.float().fill_(-1e8).type_as(t)


def select_attn(attn_logits, type='best'):
    """Select a single encoder-decoder attention map from multi-layer, multi-head attention logits.

    :param attn_logits: per-layer attention logits, overall shape [n_layers, B, n_head, T_sp, T_txt]
    :param type: 'best' picks, per sample, the layer/head whose alignment is most sharply focused;
        'mean' averages over all layers and heads
    :return: [B, T_sp, T_txt]
    """
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
    # flatten (layer, head) into one axis: [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        encdec_attn = encdec_attn.gather(
            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
        return encdec_attn
    elif type == 'mean':
        return encdec_attn.mean(0)
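# Shape sketch (illustrative numbers, not from the original source): with 4 layers,
# 2 heads, batch 8, 200 spectrogram frames and 50 text tokens, attn_logits stacks to
# [4, 8, 2, 200, 50]; select_attn(attn_logits, 'best') returns an [8, 200, 50] map,
# keeping for each sample the layer/head whose per-frame attention peaks sum highest.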


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
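# Usage sketch (illustrative values, not from the original source): True marks valid steps.
#   >>> get_mask_from_lengths(torch.tensor([2, 4]))
#   tensor([[ True,  True, False, False],
#           [ True,  True,  True,  True]])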


def group_hidden_by_segs(h, seg_ids, max_len):
    """Mean-pool hidden states over segments (e.g. frame-level states grouped into phonemes).

    :param h: [B, T, H]
    :param seg_ids: [B, T], 1-based segment index of each position (0 is treated as padding)
    :param max_len: number of segments T_ph
    :return: h_gby_segs: [B, T_ph, H] averaged hidden state per segment,
        cnt_gby_segs: [B, T_ph] number of positions falling into each segment
    """
    B, T, H = h.shape
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs, cnt_gby_segs
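# Worked sketch (illustrative values, not from the original source): with
#   h       : [1, 4, H]
#   seg_ids : [[1, 1, 2, 0]]   (the last position is padding)
#   max_len : 2
# the call returns h_gby_segs[0, 0] = mean(h[0, 0], h[0, 1]), h_gby_segs[0, 1] = h[0, 2],
# and cnt_gby_segs = [[2., 1.]]; positions with seg_id 0 are dropped by the [:, 1:] slice.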


def expand_word2ph(word_encoding, ph2word):
    # word_encoding: [B, T_word, H]; ph2word: [B, T_ph] with 1-based word indices.
    # Prepend a zero vector at word index 0 (padding), then gather each phoneme's word encoding.
    word_encoding = F.pad(word_encoding, [0, 0, 1, 0])
    ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
    out = torch.gather(word_encoding, 1, ph2word_)
    return out
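# Usage sketch (illustrative values, not from the original source): with 3 word
# vectors and ph2word = [[1, 1, 2, 3, 3]], expand_word2ph returns a [1, 5, H] tensor
# repeating word 1 twice, word 2 once and word 3 twice; an index of 0 would gather
# the zero row added by F.pad.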