Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
def maximum_path(soft_attention, mask):
    """Return the highest-scoring monotonic path through a score grid.

    A dynamic program fills ``log_p[b, i, j]`` with the best total score of
    any path from (0, 0) to (i, j) that moves one cell at a time along either
    the t_x or the t_y axis; cells on the first row/column can only be
    reached straight along that edge. The path is then recovered by
    backtracking from (t_x - 1, t_y - 1) all the way to (0, 0).

    :param soft_attention: [b, t_x, t_y] per-cell scores that are summed
        along a path (presumably log-probabilities -- confirm with callers).
    :param mask: [b, t_x, t_y]; multiplied into the result.
    :return: attn: [b, t_x, t_y] 0/1 path with the dtype and device of
        ``soft_attention``, multiplied by ``mask``.

    NOTE(review): ``mask`` does not constrain the search itself -- the path
    always spans the full padded [t_x, t_y] grid and is only masked
    afterwards. Confirm this is intended for batches with mixed lengths.
    """
    b, t_x, t_y = soft_attention.size()
    if t_x == 0 or t_y == 0:
        # No cells to visit; indexing row/column 0 below would raise.
        return torch.zeros_like(soft_attention) * mask
    device = soft_attention.device
    with torch.no_grad():
        # The result is a hard 0/1 selection, so no gradient flows through the
        # search. Skipping autograd avoids recording a graph node per cell, and
        # working on a CPU copy avoids a device sync on every comparison in the
        # backtracking loop below.
        scores = soft_attention.detach().cpu()
        log_p = torch.zeros(b, t_x, t_y)
        log_p[:, 0, :] = torch.cumsum(scores[:, 0, :], dim=1)
        log_p[:, :, 0] = torch.cumsum(scores[:, :, 0], dim=1)
        for i in range(1, t_x):
            for j in range(1, t_y):
                max_prev = torch.max(log_p[:, i - 1, j], log_p[:, i, j - 1])
                log_p[:, i, j] = max_prev + scores[:, i, j]

        path = torch.zeros_like(scores)
        for b_idx in range(b):
            i, j = t_x - 1, t_y - 1
            while i > 0 and j > 0:
                path[b_idx, i, j] = 1
                # Step back toward whichever predecessor scored higher;
                # ties move back along t_y.
                if log_p[b_idx, i - 1, j] > log_p[b_idx, i, j - 1]:
                    i -= 1
                else:
                    j -= 1
            # Now i == 0 or j == 0, so the only way back to (0, 0) is straight
            # along that edge. Since one index is 0, this slice is exactly the
            # remaining edge cells (including the current one); without it the
            # path would be cut short at the edge.
            path[b_idx, : i + 1, : j + 1] = 1
        path = path.to(device)
    return path * mask