| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import pytest |
| import torch |
| from torch.nn import CTCLoss as CTCLoss_Pytorch |
|
|
# Devices every test is parametrized over: CPU always, plus CUDA when torch can see a GPU.
DEVICES = ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']
|
|
|
|
def wrap_and_call(fn, acts, labels, device):
    """Evaluate a CTC loss ``fn`` on ``acts`` and return its costs and input gradients.

    Args:
        fn: Callable invoked as ``fn(log_probs, labels, lengths, label_lengths)``
            that returns a tensor of costs. ``log_probs`` is the log-softmax of
            ``acts`` over the last axis, with axes 0 and 1 swapped.
        acts: Activations, either a tensor or anything ``torch.FloatTensor``
            accepts (nested lists, numpy arrays). Axis 0 is used as the batch axis.
            Every batch entry gets the length ``acts.shape[1]``.
        labels: One label sequence per batch entry. Sequences may differ in
            length: shorter ones are right-padded with 0, and their true lengths
            are passed to ``fn`` as ``label_lengths``.
        device: Target device, e.g. ``'cpu'``, ``'cuda'`` or ``'cuda:1'``.

    Returns:
        ``(costs, grad)``: numpy copies of ``fn``'s output and of the gradient of
        ``sum(costs)`` w.r.t. ``acts``. ``grad`` is ``None`` when ``acts`` received
        no ``.grad`` (e.g. a tensor input that stopped being a leaf after the move).
    """
    if not torch.is_tensor(acts):
        acts = torch.FloatTensor(acts)
    # ``.to`` honours explicit device indices such as 'cuda:1' (``.cuda()`` always
    # picks the current device) and returns the same tensor when already there.
    acts = acts.to(device)
    if not acts.requires_grad:
        acts.requires_grad = True

    batch_size, num_frames = acts.shape[0], acts.shape[1]
    label_lengths = [len(seq) for seq in labels]
    # Right-pad ragged label lists so they stack into one (batch, max_len) tensor.
    # The padded positions lie beyond ``label_lengths`` and are not part of any target.
    max_label_len = max(label_lengths, default=0)
    padded_labels = [list(seq) + [0] * (max_label_len - len(seq)) for seq in labels]

    labels_t = torch.tensor(padded_labels, dtype=torch.long, device=acts.device)
    lengths_t = torch.full((batch_size,), num_frames, dtype=torch.long, device=acts.device)
    label_lengths_t = torch.tensor(label_lengths, dtype=torch.long, device=acts.device)
    log_probs = torch.nn.functional.log_softmax(acts.transpose(0, 1), -1)

    costs = fn(log_probs, labels_t, lengths_t, label_lengths_t)
    torch.sum(costs).backward()

    if acts.is_cuda:
        torch.cuda.synchronize()

    grad = acts.grad.detach().cpu().numpy() if acts.grad is not None else None
    return costs.detach().cpu().numpy(), grad
|
|
|
|
def init_k2_ctc(**kwargs):
    """Build a k2-backed CTC loss callable with the ``torch.nn.CTCLoss`` call signature.

    All ``kwargs`` are forwarded unchanged to NeMo's k2 ``CtcLoss``. The returned
    callable takes ``(log_probs, labels, lengths, label_lengths)``. It swaps axes
    0 and 1 of ``log_probs`` before calling the k2 loss and returns only element
    ``[0]`` of the loss output.
    """
    from nemo.collections.asr.parts.k2.ml_loss import CtcLoss

    loss_module = CtcLoss(**kwargs)

    def _k2_ctc(log_probs, labels, lengths, label_lengths):
        # Swap the first two axes back before handing the tensor to the k2 loss.
        swapped = log_probs.transpose(0, 1)
        outputs = loss_module(swapped, labels, lengths, label_lengths)
        return outputs[0]

    return _k2_ctc
|
|
|
|
def skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled):
    """Skip the calling test when k2 cannot run on ``device``.

    Args:
        device: ``'cpu'`` or ``'cuda'``.
        k2_is_appropriate: ``(supported, reason)`` pair consulted for ``'cpu'``.
        k2_cuda_is_enabled: ``(supported, reason)`` pair consulted for ``'cuda'``.

    Raises:
        ValueError: if ``device`` is neither ``'cpu'`` nor ``'cuda'``.
    """
    status_by_device = {'cpu': k2_is_appropriate, 'cuda': k2_cuda_is_enabled}
    if device not in status_by_device:
        raise ValueError(f"Unknown device: {device}")

    supported, msg = status_by_device[device]
    if supported:
        return
    pytest.skip(f"k2 test is skipped. Reason : {msg}")
|
|
|
|
class TestCTCLossK2:
    """Check NeMo's k2-based ``CtcLoss`` against reference values and ``torch.nn.CTCLoss``.

    Every test is parametrized over ``DEVICES``. Each test first calls
    ``skip_test_if_unsupported``, so it is skipped when the k2 fixtures report
    that k2 is not usable on that device.
    """

    @pytest.mark.unit
    @pytest.mark.parametrize('device', DEVICES)
    def test_case_small(self, device, k2_is_appropriate, k2_cuda_is_enabled):
        """Single 6-frame, 5-class utterance (blank=0) against hard-coded cost/gradients."""
        skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled)

        acts = np.array(
            [
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.8, 0.1],
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.1, 0.1],
                    [0.7, 0.1, 0.2, 0.1, 0.1],
                ]
            ]
        )
        labels = [[1, 2, 3]]

        fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum')
        k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device)

        expected_cost = 5.0279555
        expected_grads = np.array(
            [
                [
                    [0.00157518, -0.53266853, 0.17703111, 0.17703111, 0.17703111],
                    [-0.02431531, -0.17048728, -0.15925968, 0.17703113, 0.17703113],
                    [-0.06871005, 0.03236287, -0.2943067, 0.16722652, 0.16342735],
                    [-0.09178554, 0.25313747, -0.17673965, -0.16164337, 0.17703108],
                    [-0.10229809, 0.19587973, 0.05823242, -0.34769377, 0.19587973],
                    [-0.22203964, 0.1687112, 0.18645471, -0.30183747, 0.1687112],
                ]
            ]
        )

        assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
        assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test gradient mismatch."

    @pytest.mark.unit
    @pytest.mark.parametrize('device', DEVICES)
    def test_case_small_blank_last(self, device, k2_is_appropriate, k2_cuda_is_enabled):
        """20-frame, 3-class utterance whose blank is the last class index."""
        skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled)

        acts = np.array(
            [
                [
                    [0.0, 1.0, 3.0],
                    [0.0, 2.0, 3.0],
                    [1.0, 1.0, 3.0],
                    [2.0, 3.0, 2.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 1.0, 1.0],
                    [1.0, 0.0, 1.0],
                    [2.0, 2.0, 0.0],
                    [0.0, 2.0, 5.0],
                    [0.0, 3.0, 5.0],
                    [1.0, 2.0, 5.0],
                    [2.0, 4.0, 4.0],
                    [0.0, 3.0, 4.0],
                    [0.0, 4.0, 4.0],
                    [1.0, 3.0, 4.0],
                    [2.0, 5.0, 3.0],
                    [2.0, 2.0, 1.0],
                    [2.0, 3.0, 1.0],
                    [3.0, 2.0, 1.0],
                    [4.0, 4.0, 0.0],
                ]
            ]
        )
        labels = [[0, 1, 0, 0, 1, 0]]

        # Blank is the last class here, so 0 is an ordinary label.
        fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=acts.shape[-1] - 1, reduction='sum')
        k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device)

        expected_cost = 6.823422
        expected_grads = np.array(
            [
                [
                    [-0.09792291, 0.11419516, -0.01627225],
                    [-0.08915664, 0.22963384, -0.14047718],
                    [-0.19687234, 0.06477807, 0.13209426],
                    [-0.22838503, 0.1980845, 0.03030053],
                    [-0.07985485, -0.0589368, 0.13879165],
                    [-0.04722299, 0.01424287, 0.03298012],
                    [0.01492161, 0.02710512, -0.04202673],
                    [-0.43219852, 0.4305843, 0.00161422],
                    [-0.00332598, 0.0440818, -0.04075582],
                    [-0.01329869, 0.11521607, -0.10191737],
                    [-0.03721291, 0.04389342, -0.00668051],
                    [-0.2723349, 0.43273386, -0.16039898],
                    [-0.03499417, 0.1896997, -0.15470551],
                    [-0.02911933, 0.29706067, -0.26794133],
                    [-0.04593367, -0.04479058, 0.09072424],
                    [-0.07227867, 0.16096972, -0.08869105],
                    [0.13993078, -0.20230117, 0.06237038],
                    [-0.05889719, 0.04007925, 0.01881794],
                    [-0.09667239, 0.07077749, 0.0258949],
                    [-0.49002117, 0.4954626, -0.00544143],
                ]
            ]
        )

        assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test_blank_last costs mismatch."
        assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test_blank_last gradient mismatch."

    @pytest.mark.unit
    @pytest.mark.parametrize('device', DEVICES)
    def test_case_small_random(self, device, k2_is_appropriate, k2_cuda_is_enabled):
        """Seeded random 4-frame, 3-class utterance compared against torch.nn.CTCLoss."""
        skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled)

        rng = np.random.RandomState(0)
        acts = rng.randn(1, 4, 3)
        labels = [[1, 2]]

        fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum')
        k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device)

        fn_pt = CTCLoss_Pytorch(reduction='sum', zero_infinity=True)
        pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

        assert np.allclose(k2_cost, pt_cost, rtol=1e-6), "small_random_test costs mismatch."
        assert np.allclose(k2_grads, pt_grads, atol=1e-6), "small_random_test gradient mismatch."

    @pytest.mark.unit
    @pytest.mark.parametrize('device', DEVICES)
    def test_case_big_tensor(self, device, k2_is_appropriate, k2_cuda_is_enabled):
        """Batch of two 12-frame, 3-class utterances; per-utterance costs (reduction='none')."""
        skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled)

        acts = [
            [
                [0.06535690384862791, 0.7875301411923206, 0.08159176605666074],
                [0.5297155426466327, 0.7506749639230854, 0.7541348379087998],
                [0.6097641124736383, 0.8681404965673826, 0.6225318186056529],
                [0.6685222872103057, 0.8580392805336061, 0.16453892311765583],
                [0.989779515236694, 0.944298460961015, 0.6031678586829663],
                [0.9467833543605416, 0.666202507295747, 0.28688179752461884],
                [0.09418426230195986, 0.3666735970751962, 0.736168049462793],
                [0.1666804425271342, 0.7141542198635192, 0.3993997272216727],
                [0.5359823524146038, 0.29182076440286386, 0.6126422611507932],
                [0.3242405528768486, 0.8007644367291621, 0.5241057606558068],
                [0.779194617063042, 0.18331417220174862, 0.113745182072432],
                [0.24022162381327106, 0.3394695622533106, 0.1341595066017014],
            ],
            [
                [0.5055615569388828, 0.051597282072282646, 0.6402903936686337],
                [0.43073311517251, 0.8294731834714112, 0.1774668847323424],
                [0.3207001991262245, 0.04288308912457006, 0.30280282975568984],
                [0.6751777088333762, 0.569537369330242, 0.5584738347504452],
                [0.08313242153985256, 0.06016544344162322, 0.10795752845152584],
                [0.7486153608562472, 0.943918041459349, 0.4863558118797222],
                [0.4181986264486809, 0.6524078485043804, 0.024242983423721887],
                [0.13458171554507403, 0.3663418070512402, 0.2958297395361563],
                [0.9236695822497084, 0.6899291482654177, 0.7418981733448822],
                [0.25000547599982104, 0.6034295486281007, 0.9872887878887768],
                [0.5926057265215715, 0.8846724004467684, 0.5434495396894328],
                [0.6607698886038497, 0.3771277082495921, 0.3580209022231813],
            ],
        ]

        expected_costs = [6.388067, 5.2999153]
        expected_grads = [
            [
                [0.06130501, -0.3107036, 0.24939862],
                [0.08428053, -0.07131141, -0.01296911],
                [-0.04510102, 0.21943177, -0.17433074],
                [-0.1970142, 0.37144178, -0.17442757],
                [-0.08807078, 0.35828218, -0.2702114],
                [-0.24209887, 0.33242193, -0.09032306],
                [-0.07871056, 0.3116736, -0.23296304],
                [-0.27552277, 0.43320477, -0.157682],
                [-0.16173504, 0.27361175, -0.1118767],
                [-0.13012655, 0.42030025, -0.2901737],
                [-0.2378576, 0.26685005, -0.02899244],
                [0.08487711, 0.36765888, -0.45253596],
            ],
            [
                [-0.14147596, -0.2702151, 0.41169107],
                [-0.05323913, -0.18442528, 0.23766442],
                [-0.24160458, -0.11692462, 0.3585292],
                [-0.1004294, -0.17919227, 0.27962166],
                [-0.01819841, -0.12625945, 0.14445786],
                [-0.00131121, 0.06060241, -0.0592912],
                [-0.09093696, 0.2536721, -0.16273515],
                [-0.08962183, 0.34198248, -0.25236064],
                [-0.19668606, 0.25176668, -0.05508063],
                [0.0232805, 0.1351273, -0.1584078],
                [0.09494846, -0.17026341, 0.07531495],
                [0.00775955, -0.30424336, 0.29648378],
            ],
        ]

        acts = np.array(acts)
        expected_costs = np.array(expected_costs)
        labels = [[1, 2, 2, 2, 2], [1, 1, 2, 2, 1]]

        # reduction='none' yields one cost per utterance; these are not averages.
        fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='none')
        k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device)

        assert np.allclose(k2_costs, expected_costs), "big_test per-utterance costs mismatch."
        assert np.allclose(k2_grads, expected_grads, rtol=1e-3), "big_test gradient mismatch."

    @pytest.mark.unit
    @pytest.mark.parametrize('device', DEVICES)
    def test_case_large_random(self, device, k2_is_appropriate, k2_cuda_is_enabled):
        """Seeded random batch of four 80-frame, 5-class utterances compared against torch.nn.CTCLoss."""
        skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled)

        rng = np.random.RandomState(0)
        acts = rng.randn(4, 80, 5)
        labels = [
            [1, 2, 4, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 3, 1, 1, 1],
            [3, 2, 2, 3, 4, 1, 1, 1, 1, 1, 4, 4, 1, 2, 1, 3, 4, 3, 1, 2],
            [4, 4, 1, 2, 1, 3, 4, 3, 1, 2, 3, 2, 2, 3, 4, 1, 1, 1, 1, 1],
            [1, 1, 2, 1, 2, 3, 3, 1, 1, 1, 1, 2, 4, 3, 2, 2, 1, 1, 1, 1],
        ]

        fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum')
        k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device)

        fn_pt = CTCLoss_Pytorch(reduction='sum', zero_infinity=True)
        pt_costs, pt_grads = wrap_and_call(fn_pt, acts, labels, device)

        assert np.allclose(k2_costs, pt_costs, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch."
        assert np.allclose(k2_grads, pt_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch."
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible when run as a script.
    raise SystemExit(pytest.main([__file__]))
|
|