Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,475 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import threading
import torch
import queue
from torch.utils.data import DataLoader
class DataLoaderX(DataLoader):
    """DataLoader that prefetches the next batch onto a CUDA device.

    While the caller works on the current batch, the following batch is
    fetched from the underlying iterator and its tensors are copied to
    ``local_rank`` on a dedicated CUDA stream.

    Batches are expected to be dict-like (``preload`` iterates
    ``batch.items()``); only tensor values are moved, other values are left
    untouched. Batches that are ``None`` are skipped.
    """

    # Unique marker handed to ``next`` as its default so that "iterator is
    # exhausted" can be told apart from a batch whose value is ``None``.
    _EXHAUSTED = object()

    def __init__(self, local_rank, **kwargs):
        """Create the loader and the side stream used for prefetching.

        Args:
            local_rank: Device passed to ``torch.cuda.Stream`` and used as the
                ``device`` argument of ``Tensor.to``.
            **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
        """
        super().__init__(**kwargs)
        # One extra CUDA stream per loader (i.e. per process) for the copies.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a new pass over the data and prefetch the first batch."""
        self.iter = super().__iter__()
        self.preload()
        return self

    def preload(self):
        """Fetch the next non-``None`` batch and start copying it to the device.

        Sets ``self.batch`` to the fetched batch, or to ``None`` once the
        underlying iterator is exhausted.
        """
        # Fetch the next value, skipping batches that are ``None``. The
        # sentinel default detects exhaustion without relying on private
        # iterator attributes or on ``len()`` of the iterator.
        batch = next(self.iter, self._EXHAUSTED)
        while batch is None:
            batch = next(self.iter, self._EXHAUSTED)

        if batch is self._EXHAUSTED:
            self.batch = None
            return None

        self.batch = batch
        # Issue the host-to-device copies on the side stream ahead of use.
        # NOTE(review): ``non_blocking=True`` presumably relies on the loader
        # being created with ``pin_memory=True`` to be truly asynchronous, and
        # the copied tensors are not ``record_stream``-ed -- confirm whether
        # either matters for the callers.
        with torch.cuda.stream(self.stream):
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch[key] = val.to(
                        device=self.local_rank, non_blocking=True
                    )
        return None

    def __next__(self):
        """Return the prefetched batch and start prefetching the next one.

        Raises:
            StopIteration: When no batch is left.
        """
        # Make the current stream wait until the side-stream copies are done.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch
|