Spaces:
Running
Running
from typing import TYPE_CHECKING
from threading import Thread, Event
from queue import Queue
import time
import numpy as np
import torch
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
    from ding.framework import OfflineRLContext
class OfflineMemoryDataFetcher:
    """
    Overview:
        Middleware that prefetches batches from an in-memory offline ``Dataset``
        on a background producer thread and hands one batch per call to the
        context as ``ctx.train_data``. When ``cfg.policy.cuda`` is set, each
        batch is stacked and copied to the local device on a dedicated CUDA
        stream so host-to-device transfers overlap with training compute.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode only processes holding the FETCHER role run this
        # middleware; every other process gets a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset):
        """
        Arguments:
            - cfg: config object; reads ``cfg.policy.cuda`` and
              ``cfg.policy.batch_size``.
            - dataset: map-style offline dataset whose items are sequences of
              per-field tensors (all items share the same field layout).
        """
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            sbatch_size = batch_size * get_world_size()
            rank = get_rank()
            # Shard a global permutation across ranks: global batch i spans
            # indices [i * sbatch_size, (i + 1) * sbatch_size), and this rank
            # owns the sub-slice of length batch_size at offset
            # rank * batch_size inside that span.
            # (Bug fix: the original sliced at offset ``i`` — stride 1 — so
            # successive windows overlapped almost entirely and ranks drew
            # from each other's shards.)
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            for i in range(len(dataset) // sbatch_size):
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)

            def next_batch():
                # Assemble one batch of batch_size samples; when the current
                # index sequence is exhausted, reshuffle and restart.
                # NOTE(review): after the first pass the fresh permutation is
                # NOT re-sharded per rank (matches original behavior) —
                # confirm whether that is intentional.
                nonlocal idx_iter
                data = []
                for _ in range(batch_size):
                    try:
                        data.append(dataset[next(idx_iter)])
                    except StopIteration:
                        idx_iter = iter(np.random.permutation(len(dataset)))
                        data.append(dataset[next(idx_iter)])
                # Transpose list-of-samples into per-field lists, then stack
                # each field into one tensor.
                data = [[sample[j] for sample in data] for j in range(len(data[0]))]
                if device != 'cpu':
                    return [torch.stack(field).to(device) for field in data]
                return [torch.stack(field) for field in data]

            def run_loop():
                # Keep the bounded queue topped up until shutdown is requested.
                # (Bug fix: the original only tested ``event`` after a put, so
                # the thread spun forever if the queue stayed full at exit.)
                while not event.is_set():
                    if queue.full():
                        time.sleep(0.1)
                    else:
                        queue.put(next_batch())

            if device != 'cpu':
                # Run all producer-side copies on a side stream so they can
                # overlap with compute on the default stream.
                with torch.cuda.stream(stream):
                    run_loop()
            else:
                run_loop()

        # Bounded queue caps prefetched batches (and thus pinned GPU memory).
        self.queue = Queue(maxsize=50)
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer'
        )

    def __call__(self, ctx: "OfflineRLContext"):
        """
        Overview:
            Lazily start the producer thread on first use, then block until a
            batch is available and expose it as ``ctx.train_data``.
        """
        if not self.producer_thread.is_alive():
            time.sleep(5)
            self.producer_thread.start()
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        # Signal the producer to stop, then drop the queue so any staged
        # tensors are released promptly.
        if self.producer_thread.is_alive():
            self.event.set()
        del self.queue