| | import math |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence |
| |
|
| |
|
def sort_pack_padded_sequence(input, lengths):
    """Reorder a padded batch by descending length and pack it.

    ``input`` is indexed along its first dimension with the sort permutation
    of ``lengths`` and packed with ``batch_first=True``.

    Returns:
        A tuple ``(packed, inv_ix)``: the packed sequence and the inverse of
        the sort permutation, which restores the original batch order.
    """
    lengths_desc, order = torch.sort(lengths, descending=True)
    # The sorted lengths are moved to the CPU before being handed to the packer.
    packed = pack_padded_sequence(input[order], lengths_desc.cpu(), batch_first=True)

    # Scatter 0..B-1 through the permutation to obtain its inverse.
    positions = torch.arange(0, len(order)).type_as(order)
    inv_ix = torch.empty_like(order)
    inv_ix[order] = positions
    return packed, inv_ix
| |
|
def pad_unsort_packed_sequence(input, inv_ix):
    """Pad a packed sequence (batch first) and reorder its rows with ``inv_ix``.

    ``inv_ix`` is used to index the first dimension of the padded tensor; the
    lengths returned by the unpacking step are discarded.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
| |
|
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Run ``module`` over a padded batch via its packed representation.

    The batch is sorted and packed, ``module`` is applied, and the result is
    padded again and restored to the original batch order.

    * ``torch.nn.RNNBase`` instances receive the packed sequence itself and
      only the first element of their return value is kept.
    * Any other module is applied to the flat packed data, and its output is
      re-wrapped with the original batch sizes.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if not isinstance(module, torch.nn.RNNBase):
        flat_data, batch_sizes = packed[0], packed[1]
        result = PackedSequence(module(flat_data), batch_sizes)
    else:
        result = module(packed)[0]
    return pad_unsort_packed_sequence(result, inv_ix)
| |
|
def generate_length_mask(lens, max_length=None):
    """Build a boolean mask marking the valid (non-padded) positions.

    Args:
        lens: Per-sequence lengths, anything accepted by ``torch.as_tensor``;
            it is flattened to one length per mask row.
        max_length: Number of mask columns. When ``None`` the largest entry
            of ``lens`` is used (0 for an empty batch).

    Returns:
        Bool tensor of shape ``[N, max_length]`` on the device of ``lens``
        where ``mask[i, t]`` is True exactly when ``t < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # A tensor reduction replaces the Python builtin ``max``, which walks
        # the tensor element by element and raises on an empty batch.
        max_length = int(lens.max()) if lens.numel() > 0 else 0
    # Create the step indices directly on the target device and let
    # broadcasting expand them, instead of materialising an N x T grid.
    steps = torch.arange(max_length, device=lens.device)
    return steps.unsqueeze(0) < lens.view(-1, 1)
| |
|
def mean_with_lens(features, lens):
    """
    Average ``features`` over the length dimension, ignoring padded steps.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]
    returns: [N, ...], the sum over the first ``lens[i]`` steps divided by
        ``lens[i]``

    Note: no guard is applied to the division, so an entry of ``lens`` equal
    to 0 divides a zero sum by zero for that row.
    """
    lens = torch.as_tensor(lens)
    # The mask is always sized to T. This covers both the "padded beyond the
    # longest sequence" case and the exact-fit case without scanning ``lens``
    # with the Python builtin ``max``.
    mask = generate_length_mask(lens, features.size(1)).to(features.device)

    # Append singleton dims so the [N, T] mask broadcasts over trailing dims.
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    feature_sum = (features * mask).sum(1)

    # Same for the divisor: [N] -> [N, 1, ...] to line up with the summed features.
    lens = lens.to(features.device)
    while lens.ndim < feature_sum.ndim:
        lens = lens.unsqueeze(1)
    return feature_sum / lens
| |
|
def max_with_lens(features, lens):
    """
    Take the maximum of ``features`` over the length dimension, ignoring
    padded steps.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]
    returns: [N, ...], the maximum over the first ``lens[i]`` steps of each row

    Padded positions are replaced by ``-inf`` before the reduction, so a row
    whose length is 0 yields ``-inf``.
    """
    lens = torch.as_tensor(lens)
    # Size the mask to T rather than to the longest length: a batch padded
    # beyond its longest sequence would otherwise give a mask whose shape
    # does not match ``features``.
    mask = generate_length_mask(lens, features.size(1)).to(features.device)

    # Append singleton dims so the [N, T] mask broadcasts over trailing dims.
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)

    # masked_fill returns a new tensor, so ``features`` itself is not modified.
    feature_max, _ = features.masked_fill(~mask, float("-inf")).max(1)
    return feature_max
| |
|
def repeat_tensor(x, n):
    """Stack ``n`` copies of ``x`` along a new leading dimension.

    The result has shape ``[n, *x.shape]``.
    """
    # Repeat n times along the new axis and once along every existing axis.
    repeats = [n] + [1] * len(x.shape)
    return x.unsqueeze(0).repeat(*repeats)
| |
|
def _init_weight(weight, method):
    """Fill ``weight`` in place with the named scheme ("kaiming" or "xavier")."""
    if method == "kaiming":
        nn.init.kaiming_uniform_(weight)
    elif method == "xavier":
        nn.init.xavier_uniform_(weight)
    else:
        # ValueError is a subclass of Exception, so ``except Exception``
        # handlers keep working.
        raise ValueError(f"initialization method {method} not supported")


def init(m, method="kaiming"):
    """Initialise the parameters of a single module in place.

    Args:
        m: The module to initialise. Modules of any other type than those
            listed below are left untouched.
        method: ``"kaiming"`` or ``"xavier"`` (both uniform variants); used
            for the weights of conv, linear and embedding modules.

    Behaviour per module type:
        * ``Conv1d`` / ``Conv2d`` / ``Linear``: weight by ``method``, bias
          (when present) set to 0.
        * ``BatchNorm1d`` / ``BatchNorm2d``: weight set to 1 and bias to 0,
          each only when present.
        * ``Embedding``: weight by ``method``.

    Raises:
        ValueError: if ``method`` is not supported and ``m`` is a conv,
            linear or embedding module.
    """
    if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)):
        _init_weight(m.weight, method)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # With affine=False both weight and bias are None; skip them instead
        # of passing None to the initialiser.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        _init_weight(m.weight, method)
| |
|
| |
|
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input, then apply dropout.

    The table ``pe`` has shape ``[max_len, 1, d_model]``: even feature
    columns hold sines and odd columns hold cosines of the position scaled
    by geometrically decreasing frequencies. It is stored as a
    non-trainable ``nn.Parameter`` named ``"pe"``.

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # [max_len, d_model] -> [max_len, 1, d_model] so it broadcasts over dim 1.
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        """Add the encodings of the first ``x.size(0)`` positions to ``x``.

        The first dimension of ``x`` is treated as the position index
        (presumably ``x`` is ``[T, N, d_model]`` -- confirm against callers).
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
| |
|