| import torch | |
def _zero_padded_window(features: torch.Tensor, start: int, stop: int) -> torch.Tensor:
    """Return rows ``[start, stop)`` of ``features``, zero-filling any row outside ``[0, n)``.

    The result always has exactly ``stop - start`` rows (for ``stop >= start``),
    no matter how far the window extends past either end of ``features``.
    Zeros share the dtype and device of ``features``.
    """
    n = features.shape[0]
    # Clamp to valid rows; ``hi >= lo`` keeps the slice from wrapping or going negative.
    lo = max(start, 0)
    hi = max(min(stop, n), lo)
    # Window positions that fall before row 0 / at-or-after row n.
    n_before = max(0, min(stop, 0) - start)
    n_after = max(0, stop - max(start, n))

    middle = features[lo:hi]
    if n_before == 0 and n_after == 0:
        return middle

    # Build pads with explicit sizes: zeros_like(middle[:k]) would be too short
    # whenever k exceeds the number of available rows.
    trailing = features.shape[1:]
    parts = []
    if n_before > 0:
        parts.append(features.new_zeros((n_before, *trailing)))
    parts.append(middle)
    if n_after > 0:
        parts.append(features.new_zeros((n_after, *trailing)))
    return torch.cat(parts, dim=0)


def get_audio_features(features: torch.Tensor, att_mode: int, index: int) -> torch.Tensor:
    """Select the audio-feature frames associated with frame ``index``.

    Args:
        features: per-frame features indexed along dim 0. The trailing dims are
            passed through unchanged (presumably ``(n_frames, feat_dim)``; TODO confirm).
        att_mode: selection mode.
            0 -> only frame ``index``, returned as ``features[[index]]`` (1 row).
            1 -> the 8 frames before ``index``, i.e. rows ``[index - 8, index)``.
            2 -> the 8 frames centred on ``index``, i.e. rows ``[index - 4, index + 4)``.
        index: frame index to select around.

    Returns:
        Mode 0: a tensor of shape ``(1, *features.shape[1:])``.
        Modes 1 and 2: a tensor of shape ``(8, *features.shape[1:])``. Rows that
        fall outside ``features`` are filled with zeros.

    Raises:
        NotImplementedError: if ``att_mode`` is not 0, 1 or 2.
    """
    if att_mode == 0:
        return features[[index]]
    if att_mode == 1:
        return _zero_padded_window(features, index - 8, index)
    if att_mode == 2:
        return _zero_padded_window(features, index - 4, index + 4)
    raise NotImplementedError(f'wrong att_mode: {att_mode}')