|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def build_word_mask(x2word, y2word): |
|
|
return (x2word[:, :, None] == y2word[:, None, :]).long() |
|
|
|
|
|
|
|
|
def mel2ph_to_mel2word(mel2ph, ph2word): |
|
|
mel2word = (ph2word - 1).gather(1, (mel2ph - 1).clamp(min=0)) + 1 |
|
|
mel2word = mel2word * (mel2ph > 0).long() |
|
|
return mel2word |
|
|
|
|
|
|
|
|
def clip_mel2token_to_multiple(mel2token, frames_multiple): |
|
|
max_frames = mel2token.shape[1] // frames_multiple * frames_multiple |
|
|
mel2token = mel2token[:, :max_frames] |
|
|
return mel2token |
|
|
|
|
|
|
|
|
def expand_states(h, mel2token): |
|
|
h = F.pad(h, [0, 0, 1, 0]) |
|
|
mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) |
|
|
h = torch.gather(h, 1, mel2token_) |
|
|
return h |
|
|
|