Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import subprocess | |
| import tempfile | |
class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.

    The shared-memory path is only used when the array is at least 128MB and
    ``pyarrow.plasma`` can be imported; otherwise the instance is pickled
    like any ordinary object (the array travels inside the pickle).
    """

    def __init__(self, array):
        """Wrap *array*, which must expose an ``nbytes`` attribute."""
        super().__init__()
        self.array = array
        # disable for arrays <128MB
        self.disable = array.nbytes < 128 * 1024 * 1024
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """Return the lazily imported ``pyarrow.plasma`` module, or ``None``.

        ``None`` is returned when plasma is disabled for this array or when
        ``pyarrow.plasma`` cannot be imported. This must be a property: the
        rest of the class tests ``self.plasma is None`` and calls
        ``self.plasma.connect(...)``.
        """
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        """Launch a ``plasma_store`` subprocess for this array.

        Does nothing when plasma is unavailable or a server was already
        started by this instance.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        # The temporary file's name is handed to plasma_store via "-s"; the
        # handle is kept on the instance so it is only closed in __del__.
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen(
            [
                "plasma_store",
                "-m",
                # request 5% more than the array size
                str(int(1.05 * self.array.nbytes)),
                "-s",
                self.path,
            ]
        )

    @property
    def client(self):
        """Return a plasma client connected to ``self.path``.

        The connection is created on first access and cached afterwards.
        """
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Return the picklable state.

        Without plasma the full ``__dict__`` (including the array) is
        returned. With plasma, the array is put into the store (starting the
        server on first use) and the state carries only ``object_id`` and
        ``path`` instead of the array data.
        """
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        # Process-local handles are reset; the receiver reconnects lazily.
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Restore from *state*, fetching the array from plasma if needed."""
        self.__dict__.update(state)
        if self.plasma is None:
            # The array came through the pickle itself; nothing to fetch.
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        """Kill the plasma server started by this instance and close its temp file."""
        # getattr guards against instances whose __init__ never completed.
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            self._server = None
        # Closed independently of the server so the temp file is released
        # even when launching plasma_store failed.
        server_tmp = getattr(self, "_server_tmp", None)
        if server_tmp is not None:
            server_tmp.close()
            self._server_tmp = None