|
|
import unittest |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
import hypothesis.strategies as st |
|
|
from hypothesis import assume, given, settings |
|
|
from torch.testing._internal.common_utils import TestCase |
|
|
from examples.simultaneous_translation.utils.functions import exclusive_cumprod |
|
|
|
|
|
|
|
|
# Evaluated once at import time; the property test uses it to discard the
# "cuda" device cases when no CUDA device is available.
TEST_CUDA = torch.cuda.is_available()
|
|
|
|
|
|
|
|
class AlignmentTrainTest(TestCase):
    """Checks the compiled ``alignment_train`` bindings against a PyTorch reference."""

    def _test_custom_alignment_train_ref(self, p_choose, eps):
        """Build the reference ``alpha`` tensor using plain PyTorch operations.

        With ``c = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)`` the
        recurrence evaluated for every target step ``t`` is::

            alpha_t = clamp(p_t * c_t * cumsum(alpha_{t-1} / clamp(c_t, eps, 1)), 0, 1)

        where ``alpha_{-1}`` is 1 at source position 0 and 0 elsewhere.

        Args:
            p_choose: tensor indexed as (bsz, tgt_len, src_len).
            eps: forwarded to ``exclusive_cumprod`` and used as the lower
                bound of the clamped divisor.

        Returns:
            Tensor of shape (bsz, tgt_len, src_len), one slice per target step.
        """
        survive = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
        # The clamped copy is only ever used as a divisor, so it stays >= eps.
        survive_divisor = torch.clamp(survive, eps, 1.0)

        bsz, tgt_len, src_len = (p_choose.size(axis) for axis in range(3))

        # State before the first target step: all mass on source position 0.
        prev_step = p_choose.new_zeros([bsz, src_len])
        prev_step[:, 0] = 1.0

        per_step = []
        for step in range(tgt_len):
            # All three operands below are (bsz, src_len) slices.
            weight = p_choose[:, step] * survive[:, step]
            running = torch.cumsum(prev_step / survive_divisor[:, step], dim=1)
            prev_step = (weight * running).clamp(0, 1.0)
            # Re-insert the target axis so the slices can be concatenated.
            per_step.append(prev_step.unsqueeze(1))

        return torch.cat(per_step, dim=1)

    def _test_custom_alignment_train_impl(self, p_choose, alpha, eps):
        """Run the compiled binding that matches the device of ``p_choose``.

        The binding modules are imported lazily so that only the one for the
        device actually in use has to be importable.

        Args:
            p_choose: input tensor; its ``is_cuda`` flag selects the binding.
            alpha: tensor handed to the binding as its output argument (the
                calling test reads it back after this method returns).
            eps: passed through to the binding unchanged.
        """
        if not p_choose.is_cuda:
            from alignment_train_cpu_binding import alignment_train_cpu

            alignment_train_cpu(p_choose, alpha, eps)
            return

        from alignment_train_cuda_binding import alignment_train_cuda

        alignment_train_cuda(p_choose, alpha, eps)

    @settings(deadline=None)
    @given(
        bsz=st.integers(1, 100),
        tgt_len=st.integers(1, 100),
        src_len=st.integers(1, 550),
        device=st.sampled_from(["cpu", "cuda"]),
    )
    def test_alignment_train(self, bsz, tgt_len, src_len, device):
        """Property test: binding output matches the reference within 1e-3."""
        eps = 1e-6

        # Discard the "cuda" examples when no CUDA device is available.
        assume(TEST_CUDA or device == "cpu")
        p_choose = torch.rand(bsz, tgt_len, src_len, device=device)

        # Actual: the binding receives a zero tensor as its output argument.
        kernel_out = p_choose.new_zeros([bsz, tgt_len, src_len])
        self._test_custom_alignment_train_impl(p_choose, kernel_out, eps)

        # Expected: the pure-PyTorch recurrence on the same input.
        reference_out = self._test_custom_alignment_train_ref(p_choose, eps)

        actual, expected = (
            tensor.cpu().detach().numpy() for tensor in (kernel_out, reference_out)
        )
        np.testing.assert_allclose(actual, expected, atol=1e-3, rtol=1e-3)
|
|
|
|
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|
|