| | import contextlib |
| | import unittest |
| | import tempfile |
| | from io import StringIO |
| |
|
| | import numpy as np |
| |
|
| | from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model |
| |
|
| | try: |
| | from pyarrow import plasma |
| | from fairseq.data.plasma_utils import PlasmaView, PlasmaStore |
| |
|
| | PYARROW_AVAILABLE = True |
| | except ImportError: |
| | PYARROW_AVAILABLE = False |
| |
|
# Placeholder path handed to every PlasmaView in the tests below; the tests
# never create anything at this location (presumably PlasmaView only uses it
# as part of the object-id key -- confirm in fairseq.data.plasma_utils).
dummy_path: str = "dummy"
| |
|
| |
|
@unittest.skipUnless(PYARROW_AVAILABLE, "pyarrow.plasma is not installed")
class TestPlasmaView(unittest.TestCase):
    """Tests for ``PlasmaView`` / ``PlasmaStore`` against live plasma stores.

    Each test gets its own store (``self.server``) listening on a fresh
    temporary path (``self.path``), plus a raw plasma client (``self.client``)
    connected to it so object counts can be inspected independently of the
    ``PlasmaView`` instances under test.

    Assertions use ``self.assert*`` / ``np.testing`` rather than bare
    ``assert`` so they are not stripped when Python runs with ``-O``.
    """

    def setUp(self) -> None:
        self.tmp_file = tempfile.NamedTemporaryFile()
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        try:
            self.client = plasma.connect(self.path, num_retries=10)
        except Exception:
            # tearDown is not called when setUp raises, so release the temp
            # file and the store process here instead of leaking them.
            self.tmp_file.close()
            self.server.kill()
            raise

    def tearDown(self) -> None:
        try:
            self.client.disconnect()
        finally:
            # Always stop the store, even if disconnecting the client failed.
            try:
                self.tmp_file.close()
            finally:
                self.server.kill()

    def test_two_servers_do_not_share_object_id_space(self):
        """The same (path, hash) key on two different stores yields independent objects."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                self.assertEqual(len(arr1.client.list()), 1)
                np.testing.assert_array_equal(arr1.array, data_server_1)
                # Same key, different store: must not pick up server 1's data.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                np.testing.assert_array_equal(arr2.array, data_server_2)
                np.testing.assert_array_equal(arr1.array, data_server_1)
            finally:
                # Kill the extra store even when an assertion above fails.
                server.kill()

    def test_hash_collision(self):
        """Reusing an existing key returns the stored object; a new key adds one."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(arr1.client.list()), 1)
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(arr1.client.list()), 1)
        self.assertEqual(len(arr2.client.list()), 1)
        # Key (dummy_path, 1) already exists, so arr2 sees arr1's data rather
        # than data_server_2.
        np.testing.assert_array_equal(arr2.array, data_server_1)

        msg = "No new object was created by using a novel hash key"
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        self.assertEqual(len(arr2.client.list()), 2, msg)
        self.assertIn(arr3.object_id, arr2.client.list(), msg)
        self.assertIn(arr3.object_id, arr3.client.list(), msg)
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        """Assert that two PlasmaViews expose element-wise equal arrays."""
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        """Views sharing a key reuse one stored object; a different key adds a second."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        self.assertEqual(len(self.client.list()), 1)
        arr1b = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        # With data=None the view must resolve the object already in the store.
        arr1c = PlasmaView(None, dummy_path, 1, plasma_path=self.path)

        self.assertEqual(len(self.client.list()), 1)
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(data, dummy_path, 2, plasma_path=self.path)
        self.assertEqual(len(self.client.list()), 2)

        # A freshly connected client must see the same objects; disconnect it
        # afterwards so the connection does not outlive the test.
        new_client = plasma.connect(self.path)
        try:
            self.assertEqual(len(new_client.list()), 2)
        finally:
            new_client.disconnect()
        self.assertIsInstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        """Storing an array larger than the store's capacity raises PlasmaStoreFull."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (~80 KB) cannot fit in 10000 bytes.
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                # Kill the store even if the expected exception was not raised.
                server.kill()

    def test_object_id_overflow(self):
        """Passes as long as get_object_id("", 2 ** 21) does not raise."""
        PlasmaView.get_object_id("", 2 ** 21)

    def test_training_lm_plasma(self):
        """Train a tiny transformer_lm end to end with --use-plasma-view on this test's store."""
        # Training output is noisy; silence stdout for the duration.
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )
| |
|