# style-bert-vits2-fastapi/src/sbv2/monotonic_align.py
import torch


@torch.jit.script
def maximum_path(soft_attention: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Find the monotonic alignment path that maximizes the summed attention score.

    Dynamic programming over the [t_x, t_y] grid: the path starts at (0, 0),
    ends at (t_x - 1, t_y - 1), and each step moves one cell along either
    the t_x axis or the t_y axis.

    :param soft_attention: [b, t_x, t_y] attention scores
    :param mask: [b, t_x, t_y] binary mask of valid positions
    :return: path: [b, t_x, t_y] hard (0/1) monotonic alignment
    """
    b, t_x, t_y = soft_attention.size()
    device = soft_attention.device

    # Forward pass: log_p[i, j] holds the best cumulative score of any
    # monotonic path from (0, 0) to (i, j).
    log_p = torch.zeros(b, t_x, t_y, device=device)
    # Boundary cells are reachable along only one axis, so they are
    # plain cumulative sums.
    log_p[:, 0, :] = torch.cumsum(soft_attention[:, 0, :], dim=1)
    log_p[:, :, 0] = torch.cumsum(soft_attention[:, :, 0], dim=1)
    for i in range(1, t_x):
        for j in range(1, t_y):
            # Each interior cell extends the better of its two predecessors.
            max_prev = torch.max(log_p[:, i - 1, j], log_p[:, i, j - 1])
            log_p[:, i, j] = max_prev + soft_attention[:, i, j]

    # Backtracking: from the bottom-right corner, step to whichever
    # predecessor scored higher, marking every visited cell.
    # Note: this assumes each batch element uses the full t_x and t_y;
    # padded positions are only zeroed out by the final mask multiply.
    path = torch.zeros_like(soft_attention)
    for b_idx in range(b):
        i = t_x - 1
        j = t_y - 1
        while i > 0 and j > 0:
            path[b_idx, i, j] = 1
            if log_p[b_idx, i - 1, j] > log_p[b_idx, i, j - 1]:
                i -= 1
            else:
                j -= 1
        # Finish the walk along whichever boundary was reached first, so the
        # path stays connected all the way back to (0, 0). (Without this, any
        # remaining boundary cells would be left unmarked.)
        while i > 0:
            path[b_idx, i, j] = 1
            i -= 1
        while j > 0:
            path[b_idx, i, j] = 1
            j -= 1
        path[b_idx, 0, 0] = 1
    return path * mask
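

# Minimal usage sketch (not part of the original file): exercises maximum_path
# with random inputs and inspects the returned hard alignment. The tensor
# sizes here are arbitrary illustrative values.
if __name__ == "__main__":
    b, t_x, t_y = 2, 5, 8
    soft_attention = torch.rand(b, t_x, t_y)
    mask = torch.ones(b, t_x, t_y)
    path = maximum_path(soft_attention, mask)
    print(path.shape)  # torch.Size([2, 5, 8])
    # Cells on the chosen monotonic path are 1; all other cells are 0.
    print(path[0])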