# flexpert / Flexpert-Design / src/datasets/fast_dataloader.py
# (initial commit by Honzus24, rev 7968cb0)
import threading
import torch
import queue
from torch.utils.data import DataLoader
class DataLoaderX(DataLoader):
    """DataLoader that prefetches the next batch onto the GPU.

    While the training loop consumes batch N, batch N+1 is already being
    copied host->device on a dedicated CUDA stream, overlapping the
    transfer with compute on the default stream.

    Assumes each batch is a dict whose tensor values should be moved to
    the device (non-tensor values are left untouched).

    Args:
        local_rank: CUDA device index owned by this process; the prefetch
            stream is created on it and batches are moved to it.
        **kwargs: forwarded verbatim to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        # Dedicated stream so the async H2D copy can overlap with compute
        # running on the default stream of this process.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        # Kick off the copy of the first batch immediately.
        self.preload()
        return self

    def preload(self):
        """Fetch the next usable batch and start its async copy to the GPU.

        Leaves ``self.batch`` set to ``None`` when the underlying iterator
        is exhausted, which makes ``__next__`` raise ``StopIteration``.
        ``None`` batches produced mid-epoch (e.g. by a collate_fn that
        drops bad samples) are skipped rather than treated as the end.
        """
        self.batch = None
        while self.batch is None:
            try:
                # Detect exhaustion via StopIteration instead of a None
                # sentinel. The previous implementation peeked at the
                # private ``_send_idx`` attribute, which does not exist on
                # single-process iterators and raised AttributeError there.
                self.batch = next(self.iter)
            except StopIteration:
                return
        with torch.cuda.stream(self.stream):
            # Start non-blocking host->device copies for every tensor value.
            for key, val in self.batch.items():
                if isinstance(val, torch.Tensor):
                    self.batch[key] = val.to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        # Check the exhaustion sentinel first so a finished loader does not
        # perform a needless stream synchronization.
        if self.batch is None:
            raise StopIteration
        # Make sure the prefetch copy has completed before the batch is used.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        # Immediately start copying the batch after this one.
        self.preload()
        return batch