Spaces: Running on Zero
| import threading | |
| import torch | |
| import queue | |
| from torch.utils.data import DataLoader | |
class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches the next batch onto a GPU on a side CUDA stream.

    While the caller works on batch *i*, the tensors of batch *i + 1* are
    copied to ``local_rank`` with ``non_blocking=True`` on ``self.stream``, so
    host-to-device transfers can overlap with work on the current stream.
    The copies are only truly asynchronous when the source tensors live in
    pinned memory (e.g. pass ``pin_memory=True``).

    Batches are expected to be dicts: only their top-level ``torch.Tensor``
    values are moved; other values are returned unchanged. Batches that are
    ``None`` are skipped.

    Args:
        local_rank: CUDA device index used both for the side stream and as the
            destination of the tensor copies.
        **kwargs: forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        # A dedicated CUDA stream (one per process) for the prefetch copies.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank
        # Prefetched batch: None before iteration starts and after exhaustion.
        self.batch = None

    def __iter__(self):
        """Start a new pass over the data and prefetch its first batch."""
        self.iter = super().__iter__()
        self.preload()
        return self

    def preload(self):
        """Fetch the next non-``None`` batch and start copying it to the GPU.

        Sets ``self.batch`` to that batch, whose tensor values are being moved
        to ``self.local_rank`` on ``self.stream``, or to ``None`` once the
        underlying iterator is exhausted.
        """
        # A private sentinel tells "iterator exhausted" apart from a batch that
        # is itself None. Do not rely on the iterator's private ``_send_idx``:
        # it only exists on multi-process iterators, and it reaches ``len``
        # while prefetched batches are still in flight.
        exhausted = object()
        while True:
            batch = next(self.iter, exhausted)
            if batch is exhausted:
                self.batch = None
                return None
            if batch is not None:
                break
        self.batch = batch
        # Enqueue the host-to-device copies on the side stream.
        with torch.cuda.stream(self.stream):
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch[key] = val.to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        """Return the prefetched batch and begin prefetching the following one.

        Raises:
            StopIteration: when there are no more batches.
        """
        current = torch.cuda.current_stream(self.local_rank)
        # Work queued on the consumer's stream must not start before the
        # side-stream copies of this batch have finished.
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # The tensors were allocated on self.stream but will be used on
        # ``current``; record that use so the caching allocator does not hand
        # their memory to a later allocation on self.stream too early.
        for val in batch.values():
            if isinstance(val, torch.Tensor) and val.is_cuda:
                val.record_stream(current)
        self.preload()
        return batch