| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
|
|
|
|
class DataPrefetcher:
    """Overlap host-to-GPU batch copies with computation.

    Wraps an iterable data loader. While the caller works on the current
    batch, the next batch is copied to the GPU on a dedicated side CUDA stream.

    Adapted from NVIDIA apex's ImageNet example:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It can speed up a PyTorch dataloader; for background see
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    Each item yielded by the loader must be a sequence whose first two
    elements are the input and the target. Both must provide a
    ``.cuda(non_blocking=...)`` method (e.g. tensors). Any further elements
    (for example per-sample metadata) are discarded.

    NOTE(review): ``non_blocking=True`` copies only overlap with compute when
    the source tensors live in pinned memory (e.g. ``DataLoader(pin_memory=True)``);
    presumably the caller configures this — confirm.

    Usage::

        prefetcher = DataPrefetcher(loader)
        inputs, targets = prefetcher.next()
        while inputs is not None:
            ...  # consume the batch
            inputs, targets = prefetcher.next()
    """

    def __init__(self, loader):
        """Start prefetching from ``loader``.

        Args:
            loader: Any iterable of batches (see the class docstring).
        """
        self.loader = iter(loader)
        # Side stream on which the host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        # Copy/record strategies are stored as attributes, presumably so other
        # input types can plug in their own variants.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        # Start copying the first batch right away.
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying it to the GPU on ``self.stream``.

        Sets ``self.next_input`` / ``self.next_target`` to the device copies,
        or both to ``None`` once the underlying loader is exhausted.
        """
        try:
            # Keep only (input, target); any trailing elements are dropped.
            self.next_input, self.next_target, *_ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # Issue the copies on the side stream so they can overlap with work
        # already queued on the current (compute) stream.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched ``(input, target)`` and start loading the next batch.

        Returns:
            The ``(input, target)`` pair on the GPU, or ``(None, None)`` after
            the loader is exhausted.
        """
        # Block the compute stream until the side-stream copies have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # These tensors were produced on self.stream but are consumed on the
        # current stream. Recording that use keeps the caching allocator from
        # reusing their memory while the current stream still needs it.
        if batch_input is not None:
            self.record_stream(batch_input)
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target

    def _input_cuda_for_image(self):
        """Replace ``self.next_input`` with its non-blocking GPU copy."""
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(tensor):
        """Mark ``tensor`` as in use by the current CUDA stream."""
        tensor.record_stream(torch.cuda.current_stream())
|
|