import torch


def _zero_rows(n, like):
    """Return ``n`` zero rows matching ``like``'s trailing shape, dtype and device."""
    # Build the padding explicitly rather than via ``zeros_like(like[:n])``:
    # slicing clips to ``like.shape[0]``, so the pad would come out too short
    # whenever ``n`` exceeds the number of available rows.
    return torch.zeros(n, *like.shape[1:], device=like.device, dtype=like.dtype)


def get_audio_features(features, att_mode, index):
    """Select a window of rows from ``features`` around position ``index``.

    Args:
        features: tensor indexed along dim 0 (presumably one row per audio
            frame -- confirm with callers).
        att_mode: window selection mode:
            0 -> the single row ``index``, keeping a leading dim of size 1
                 (``features[[index]]``).
            1 -> the 8 rows ``index - 8`` .. ``index - 1``; zero rows are
                 prepended when ``index < 8``.
            2 -> the 8 rows ``index - 4`` .. ``index + 3``; zero rows are
                 prepended/appended where the window runs past either end
                 of ``features``.
        index: integer row position.

    Returns:
        Tensor with the same trailing shape, dtype and device as ``features``.

    Raises:
        NotImplementedError: if ``att_mode`` is not 0, 1 or 2.
    """
    if att_mode == 0:
        return features[[index]]
    elif att_mode == 1:
        left = index - 8
        pad_left = 0
        if left < 0:
            pad_left = -left
            left = 0
        auds = features[left:index]
        if pad_left > 0:
            # pad may be longer than auds, so do not use zeros_like
            auds = torch.cat([_zero_rows(pad_left, auds), auds], dim=0)
        return auds
    elif att_mode == 2:
        left = index - 4
        right = index + 4
        pad_left = 0
        pad_right = 0
        if left < 0:
            pad_left = -left
            left = 0
        if right > features.shape[0]:
            pad_right = right - features.shape[0]
            right = features.shape[0]
        auds = features[left:right]
        # Either pad may exceed len(auds) (e.g. very short ``features``), so
        # allocate exact-size zero blocks instead of slicing with zeros_like.
        if pad_left > 0:
            auds = torch.cat([_zero_rows(pad_left, auds), auds], dim=0)
        if pad_right > 0:
            auds = torch.cat([auds, _zero_rows(pad_right, auds)], dim=0)
        # e.g. [8, 16] per the original shape note
        return auds
    else:
        raise NotImplementedError(f'wrong att_mode: {att_mode}')