Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import hashlib | |
| import json | |
| import subprocess | |
| import tempfile | |
| from typing import Hashable | |
| try: | |
| import pyarrow.plasma as plasma | |
| PYARROW_AVAILABLE = True | |
| except ImportError: | |
| plasma = None | |
| PYARROW_AVAILABLE = False | |
class PlasmaArray:
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.

    Arrays smaller than 128MB are left alone (``disable`` is True) and are
    pickled as part of the instance ``__dict__``.
    """

    def __init__(self, array):
        """
        Args:
            array: object exposing ``nbytes`` (e.g. a numpy array) to wrap.
        """
        super().__init__()
        self.array = array
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """The module-level ``plasma`` object, or None when sharing is disabled.

        Resolved lazily so that the reference is re-acquired after unpickling
        (``__getstate__`` resets ``_plasma`` to None).
        """
        if self._plasma is None and not self.disable:
            self._plasma = plasma
        return self._plasma

    def start_server(self):
        """Launch a ``plasma_store`` subprocess for this array.

        Does nothing when plasma is unavailable/disabled or when a server has
        already been started by this instance.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        # The temporary file only provides a unique name that is passed to the
        # store via ``-s``; it is kept open until ``__del__``.
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        # Request 5% more memory than the array itself occupies.
        self._server = subprocess.Popen(
            ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path]
        )

    @property
    def client(self):
        """Plasma client connected to ``self.path``, created on first access."""
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Called on pickle save.

        When plasma is in use, the array is put into the store (starting the
        server first if needed) and dropped from the returned state, so only
        ``object_id`` and ``path`` travel with the pickle.
        """
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        # Process-local handles must not be serialized.
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Called on pickle load.

        Restores the attributes and, when plasma is in use, fetches the array
        back from the store through a fresh client.
        """
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        """Kill the server started by this instance and close its temp file."""
        # getattr: ``__init__`` may have failed before ``_server`` was assigned.
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None
# Socket path used when callers do not pass an explicit ``plasma_path``.
DEFAULT_PLASMA_PATH = "/tmp/plasma"


class PlasmaView:
    """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization,
    PlasmaView writes to shared memory on instantiation."""

    def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None):
        """
        Args:
            array: numpy array to store. This can be read with ``PlasmaView().array``
            split_path: the path whence the data was read, used for hashing
            hash_data: other metadata about the array that can be used to create a unique key.
                as of writing, the 3 callers in ``TokenBlockDataset`` use::

                    hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2)
            plasma_path: path passed to ``plasma.connect``; defaults to
                ``DEFAULT_PLASMA_PATH`` when None.
        """
        assert PYARROW_AVAILABLE
        assert split_path is not None
        if plasma_path is None:
            plasma_path = DEFAULT_PLASMA_PATH

        self.path = plasma_path
        self.split_path = split_path
        self._client = None  # Initialize lazily for pickle. plasma clients should not be deep copied or serialized.
        self._n = None

        self.object_id = self.get_object_id(self.split_path, hash_data)
        try:
            self.client.put(array, object_id=self.object_id)
        except plasma.PlasmaObjectExists:
            # The same key has already been stored; reuse the existing object.
            pass

    @property
    def client(self):
        """Plasma client connected to ``self.path``, created on first access."""
        if self._client is None:
            self._client = plasma.connect(self.path, num_retries=200)
        return self._client

    @property
    def array(self):
        """Fetch a read only view of an np.array, stored in plasma."""
        ret = self.client.get(self.object_id)
        return ret

    @staticmethod
    def get_object_id(split_path: str, hash_data: Hashable):
        """Returns plasma.ObjectID from hashing split_path and hash_data.

        ``hash_data`` is serialized with ``json.dumps`` before hashing, so it
        must be JSON-serializable. The digest is 20 bytes long.
        """
        hasher = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20)
        harg = json.dumps(hash_data).encode("utf-8")
        hasher.update(harg)
        return plasma.ObjectID(hasher.digest())

    def __getstate__(self):
        """Called on pickle save"""
        # Drop the live client first so that it is never part of the state.
        self.disconnect()
        state = self.__dict__.copy()
        assert state["_client"] is None
        assert "object_id" in state
        return state

    def __setstate__(self, state):
        """Called on pickle load"""
        self.__dict__.update(state)

    def __del__(self):
        self.disconnect()

    def disconnect(self):
        """Disconnect the cached client, if any, and forget it."""
        # getattr: ``__init__`` may have failed before ``_client`` was assigned,
        # and ``__del__`` still runs on such partially built instances.
        client = getattr(self, "_client", None)
        if client is not None:
            client.disconnect()
            self._client = None

    def __len__(self):
        """Save reads by caching len"""
        if self._n is None:
            self._n = len(self.array)
        return self._n
# 100 GiB, the default amount of memory requested from ``plasma_store``.
GB100 = (1024**3) * 100


class PlasmaStore:
    """Owns a ``plasma_store`` subprocess for the lifetime of the instance."""

    def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100):
        """
        Args:
            path: value passed to ``plasma_store -s`` and to ``plasma.connect``.
            nbytes: value passed to ``plasma_store -m``.
        """
        self.server = self.start(path, nbytes)

    def __del__(self):
        # getattr: ``start`` may have raised, leaving ``server`` unset.
        server = getattr(self, "server", None)
        if server is not None:
            server.kill()

    @staticmethod
    def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen:
        """Spawn ``plasma_store`` and check that a client can connect to it.

        Returns:
            The ``subprocess.Popen`` handle of the spawned store.

        Raises:
            ImportError: if pyarrow could not be imported.
        """
        if not PYARROW_AVAILABLE:
            raise ImportError("please run pip install pyarrow to use --use_plasma_view")
        # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm
        _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path])
        try:
            plasma.connect(path, num_retries=200)  # If we can't connect we fail immediately
        except Exception:
            # Do not leave an orphaned store process behind when the check fails.
            _server.kill()
            _server.wait()
            raise
        return _server