|
|
import torch |
|
|
import threading |
|
|
import pickle |
|
|
|
|
|
from torch.utils.data import IterDataPipe, communication, MapDataPipe |
|
|
|
|
|
# ``dill`` is optional. In this module it is only used as a fallback
# serializer when a DataPipe cannot be copied with plain ``pickle``.
HAS_DILL = False
try:
    import dill

    # Revert dill's registration of its types in ``pickle``'s dispatch table so
    # that the standard ``pickle`` module keeps its stock behavior; dill is
    # then used only where this module calls it explicitly.
    dill.extend(use_dill=False)
    HAS_DILL = True
except ImportError:
    # dill is unavailable; HAS_DILL keeps its initial False value.
    pass
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["DataPipeToQueuesLoop", "SpawnProcessForDataPipeline", "SpawnThreadForDataPipeline"]
|
|
|
|
|
def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
    r"""
    Serve ``source_datapipe`` over a request/response queue pair until the
    serving loop is exhausted.

    The DataPipe is wrapped in the ``DataPipeBehindQueues`` loop that matches
    its kind (``communication.iter`` for ``IterDataPipe``,
    ``communication.map`` for ``MapDataPipe``). That loop is paired with the
    corresponding queue protocol server built from ``req_queue`` and
    ``res_queue``, and then driven to completion.

    Args:
        source_datapipe: the ``IterDataPipe`` or ``MapDataPipe`` to serve.
        req_queue: request queue handed to the protocol server.
        res_queue: response queue handed to the protocol server.

    Raises:
        TypeError: if ``source_datapipe`` is neither an ``IterDataPipe`` nor a
            ``MapDataPipe``.
    """
    if isinstance(source_datapipe, IterDataPipe):
        pipe_type = communication.iter
        protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
    elif isinstance(source_datapipe, MapDataPipe):
        pipe_type = communication.map
        protocol_type = communication.protocol.MapDataPipeQueueProtocolServer
    else:
        # TypeError is a subclass of Exception, so existing ``except Exception``
        # handlers still match, but the message is now a readable string rather
        # than a (message, object) args tuple.
        raise TypeError(
            "Only supports IterDataPipe or MapDataPipe, got "
            f"{type(source_datapipe).__name__}: {source_datapipe!r}"
        )

    # Restrict torch intra-op parallelism for this worker.
    # NOTE(review): torch.set_num_threads is process-wide. When this loop runs
    # in a thread (see SpawnThreadForDataPipeline), it therefore also affects
    # the caller's process. Confirm that this is intended.
    torch.set_num_threads(1)

    # DataPipeBehindQueues is a generator that services requests; iterating
    # it to exhaustion is what keeps the server running.
    for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue),
                                            blocking_request_get=True):
        pass
|
|
|
|
|
|
|
|
def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
    r"""
    Build (but do not start) a worker process that will serve ``datapipe``.

    Both queues are created from ``multiprocessing_ctx``. The process is
    constructed with ``DataPipeToQueuesLoop`` as its target and is returned
    unstarted, so the caller decides when to call ``start()``.

    Returns:
        A ``(process, req_queue, res_queue)`` tuple.
    """
    make_queue = multiprocessing_ctx.Queue
    request_q, response_q = make_queue(), make_queue()
    worker = multiprocessing_ctx.Process(
        target=DataPipeToQueuesLoop,
        args=(datapipe, request_q, response_q),
    )
    return worker, request_q, response_q
|
|
|
|
|
|
|
|
def SpawnThreadForDataPipeline(datapipe):
    r"""
    Given a DataPipe, create a thread-local copy of it and a daemon Thread that
    will run ``DataPipeToQueuesLoop`` on that copy.

    The copy is made by a ``pickle`` round-trip. If plain pickling fails and
    ``dill`` is installed, a ``dill`` round-trip is tried instead. The thread
    is returned *unstarted*, so the caller must call ``start()`` on it.

    Returns:
        A ``(thread, req_queue, res_queue, thread_local_datapipe)`` tuple.

    Raises:
        Exception: if the DataPipe cannot be serialized with ``pickle`` (and,
            when available, ``dill``). The underlying serialization error is
            included in the exception args and chained as ``__cause__``.
    """
    req_queue = communication.queue.ThreadingQueue()
    res_queue = communication.queue.ThreadingQueue()

    try:
        new_datapipe = pickle.loads(pickle.dumps(datapipe))
    except Exception as pe:
        if not HAS_DILL:
            raise Exception(
                'Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe
            ) from pe
        # Fall back to dill, which can serialize objects such as lambdas that
        # plain pickle rejects.
        try:
            new_datapipe = dill.loads(dill.dumps(datapipe))
        except Exception as de:
            raise Exception('Unable to dill DataPipe to make thread local copy', de) from de

    thread = threading.Thread(target=DataPipeToQueuesLoop,
                              args=(new_datapipe, req_queue, res_queue),
                              daemon=True)
    return thread, req_queue, res_queue, new_datapipe
|
|
|