Spaces:
Runtime error
Runtime error
| import contextlib | |
| import unittest | |
| import tempfile | |
| from io import StringIO | |
| import numpy as np | |
| from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model | |
# Optional dependencies: plasma support needs pyarrow's plasma module and
# fairseq's plasma helpers. Record whether both imports succeeded so the
# tests can be skipped cleanly when they are unavailable.
try:
    from pyarrow import plasma
    from fairseq.data.plasma_utils import PlasmaView, PlasmaStore
except ImportError:
    PYARROW_AVAILABLE = False
else:
    PYARROW_AVAILABLE = True

# Placeholder dataset path passed to every PlasmaView in the tests below.
dummy_path = "dummy"
@unittest.skipUnless(
    PYARROW_AVAILABLE, "pyarrow plasma / fairseq plasma_utils not importable"
)
class TestPlasmaView(unittest.TestCase):
    """Tests for ``PlasmaView`` and ``PlasmaStore`` (fairseq.data.plasma_utils).

    ``setUp`` starts a plasma store at a fresh temporary path and connects a
    raw ``plasma`` client to it; ``tearDown`` releases both. The whole class
    is skipped when the optional plasma imports failed at module load.
    """

    def setUp(self) -> None:
        # The temp file's name serves as a unique path for this test's store.
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        self.client = plasma.connect(self.path, num_retries=10)

    def tearDown(self) -> None:
        self.client.disconnect()
        self.tmp_file.close()
        self.server.kill()

    def test_two_servers_do_not_share_object_id_space(self):
        """The same (path, key) pair yields independent objects on two stores."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                self.assertEqual(len(arr1.client.list()), 1)
                np.testing.assert_array_equal(arr1.array, data_server_1)
                # Same dummy_path and key, but a different store: must see its
                # own data rather than server 1's.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                np.testing.assert_array_equal(arr2.array, data_server_2)
                np.testing.assert_array_equal(arr1.array, data_server_1)
            finally:
                # Kill the extra store even if an assertion above failed.
                server.kill()

    def test_hash_collision(self):
        """Reusing a key returns the stored object; a novel key adds a new one."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(arr1.client.list()), 1)
        # Same (dummy_path, 1) key: no new object is stored, and the view
        # exposes the data stored first (data_server_1), not data_server_2.
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(arr1.client.list()), 1)
        self.assertEqual(len(arr2.client.list()), 1)
        np.testing.assert_array_equal(arr2.array, data_server_1)
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        msg = "No new object was created by using a novel hash key"
        self.assertEqual(len(arr2.client.list()), 2, msg)
        self.assertIn(arr3.object_id, arr2.client.list(), msg)
        self.assertIn(arr3.object_id, arr3.client.list(), msg)
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        """Assert that two PlasmaViews expose element-wise equal arrays.

        A staticmethod so that ``self._assert_view_equal(a, b)`` binds
        correctly; the previous signature lacked ``self`` and raised TypeError.
        """
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        """Re-putting an existing key (with data or None) leaves the store as is."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(self.client.list()), 1)
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        self.assertEqual(len(self.client.list()), 1)
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        self.assertEqual(len(self.client.list()), 2)
        new_client = plasma.connect(self.path)
        try:
            # A new client can access the same objects.
            self.assertEqual(len(new_client.list()), 2)
        finally:
            new_client.disconnect()
        self.assertIsInstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        """Putting an array larger than the store's capacity raises PlasmaStoreFull."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (80000 bytes) exceed the 10000-byte store.
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                # Kill the extra store even if the expected error was not raised.
                server.kill()

    def test_object_id_overflow(self):
        # Passes as long as deriving an id from a large integer key does not raise.
        PlasmaView.get_object_id("", 2 ** 21)

    def test_training_lm_plasma(self):
        """Train a tiny transformer_lm end to end with --use-plasma-view on this store."""
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )