| | |
| |
|
| | import hypothesis.strategies as st |
| | import numpy as np |
| | import torch |
| | from caffe2.python import core |
| | from caffe2.python.test_util import TestCase |
| | from hypothesis import given, settings |
| | from torch import nn |
| |
|
| |
|
class TestC2LSTM(TestCase):
    """Compare Caffe2's ``InferenceLSTM`` op against PyTorch's ``nn.LSTM``."""

    @given(
        bsz=st.integers(1, 5),
        seq_lens=st.integers(1, 6),
        emb_lens=st.integers(5, 10),
        hidden_size=st.integers(3, 7),
        num_layers=st.integers(1, 4),
        has_biases=st.booleans(),
        is_bidirectional=st.booleans(),
        batch_first=st.booleans(),
    )
    @settings(deadline=10000)
    def test_c2_lstm(
        self,
        bsz,
        seq_lens,
        emb_lens,
        hidden_size,
        num_layers,
        has_biases,
        is_bidirectional,
        batch_first,
    ):
        """Feed one random input through both LSTM implementations and compare.

        Both sides receive the weights of the same ``nn.LSTM`` module and the
        same all-zero initial hidden and cell states. The output sequence, the
        final hidden state and the final cell state must therefore agree.
        """
        num_directions = 2 if is_bidirectional else 1
        py_lstm = nn.LSTM(
            emb_lens,
            hidden_size,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
            bias=has_biases,
            num_layers=num_layers,
        )

        # The (h0, c0) layout is (num_layers * num_directions, batch, hidden)
        # whether or not batch_first is set. The same zero tensor is used for
        # both h0 and c0.
        hx = torch.zeros(
            num_layers * num_directions, bsz, hidden_size, dtype=torch.float32
        )

        if batch_first:
            input_shape = (bsz, seq_lens, emb_lens)
        else:
            input_shape = (seq_lens, bsz, emb_lens)
        inputs = torch.from_numpy(
            np.random.randn(*input_shape).astype(np.float32)
        )

        # Inference only: no autograd graph is needed for the comparison.
        with torch.no_grad():
            # Pass the initial state explicitly so both implementations provably
            # start from the same (h0, c0), instead of relying on the default.
            py_output, (py_h, py_c) = py_lstm(inputs, (hx, hx))

            lstm_in = [inputs, hx, hx] + [
                param.detach() for param in py_lstm._flat_weights
            ]
            c2_results = torch.ops._caffe2.InferenceLSTM(
                lstm_in, num_layers, has_biases, batch_first, is_bidirectional
            )

        np.testing.assert_array_almost_equal(
            py_output.detach().numpy(),
            c2_results[0].detach().numpy(),
            err_msg="LSTM output sequence mismatch",
        )
        np.testing.assert_array_almost_equal(
            py_h.detach().numpy(),
            c2_results[1].detach().numpy(),
            err_msg="final hidden state mismatch",
        )
        np.testing.assert_array_almost_equal(
            py_c.detach().numpy(),
            c2_results[2].detach().numpy(),
            err_msg="final cell state mismatch",
        )
| |
|