TalkingGaussian / utils /audio_utils.py
ameerazam08's picture
Upload folder using huggingface_hub
210e8a2 verified
import torch
def _zero_pad_rows(window, pad_left, pad_right):
    """Return *window* with all-zero rows added before/after it along dim 0.

    The zero rows are allocated from an explicit shape (not ``zeros_like`` of
    a slice), so the pad may be longer than ``window`` itself -- including the
    case where ``window`` has no rows at all.
    """
    if pad_left <= 0 and pad_right <= 0:
        return window
    row_shape = tuple(window.shape[1:])
    pieces = []
    if pad_left > 0:
        # new_zeros keeps the dtype and device of ``window``.
        pieces.append(window.new_zeros((pad_left, *row_shape)))
    pieces.append(window)
    if pad_right > 0:
        pieces.append(window.new_zeros((pad_right, *row_shape)))
    return torch.cat(pieces, dim=0)


def get_audio_features(features, att_mode, index):
    """Select the audio-feature rows around ``index`` along dim 0.

    Args:
        features: tensor indexed along its first dimension (presumably one
            row per frame -- confirm against the caller).
        att_mode: window selection mode.
            0 -- only row ``index`` (leading dimension kept, length 1).
            1 -- the 8 rows before ``index`` (``index`` itself excluded),
                 zero-padded on the left when the window starts before row 0.
            2 -- the 8 rows ``[index - 4, index + 4)``, zero-padded on either
                 side where the window runs past the ends of ``features``.
        index: position of the current row; expected to be a valid row index
            of ``features`` (out-of-range values are not validated here).

    Returns:
        A tensor with the same trailing dimensions, dtype and device as
        ``features``.

    Raises:
        NotImplementedError: if ``att_mode`` is not 0, 1 or 2.
    """
    if att_mode == 0:
        # List indexing keeps the leading dimension.
        return features[[index]]

    if att_mode == 1:
        left = index - 8
        pad_left = max(0, -left)
        window = features[max(left, 0):index]
        return _zero_pad_rows(window, pad_left, 0)

    if att_mode == 2:
        total = features.shape[0]
        left = index - 4
        right = index + 4
        pad_left = max(0, -left)
        pad_right = max(0, right - total)
        window = features[max(left, 0):min(right, total)]
        # The pad can exceed the clipped window when ``features`` is short,
        # so the zeros must not be derived from a slice of ``window``.
        return _zero_pad_rows(window, pad_left, pad_right)

    raise NotImplementedError(f'wrong att_mode: {att_mode}')