Spaces:
Sleeping
Sleeping
| import torch | |
| from ding.interaction.slave import Slave, TaskFail | |
class NaiveCollector(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm collector and coordinator.
        This naive version answers data requests with randomly generated timesteps.
    Interfaces:
        _process_task, _get_timestep
    """

    # Number of 'collector_data_task' requests after which the task is reported as done.
    _DONE_COUNT = 20

    def __init__(self, *args, prefix: str = '', **kwargs) -> None:
        """
        Overview:
            Initialize the slave and the per-task bookkeeping.
        Arguments:
            - prefix (:obj:`str`): Prefix of the file name used for every saved data item.
            - args, kwargs: Forwarded unchanged to the parent ``Slave`` constructor.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix
        # Both are (re)set by 'collector_start_task'. ``task_info`` stays None until
        # then, which lets a premature data request be rejected cleanly.
        self.count = 0
        self.task_info = None

    def _process_task(self, task: dict) -> dict:
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm collector for details.
        Arguments:
            - task (:obj:`EasyDict`): Task dict. Must contain key "name". A 'collector_start_task' \
                must also contain key "task_info".
        Returns:
            - result (:obj:`dict`): Task result dict.
        Raises:
            - TaskFail: If the task name is not recognised, or if a 'collector_data_task' arrives \
                before any 'collector_start_task'.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': '20', 'gpu': '1'}
        if task_name == 'collector_start_task':
            self.count = 0
            self.task_info = task['task_info']
            return {'message': 'collector task has started'}
        if task_name == 'collector_data_task':
            return self._collect_data()
        raise TaskFail(
            result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
        )

    def _collect_data(self) -> dict:
        """
        Overview:
            Handle one 'collector_data_task': save a freshly generated timestep list to disk and
            return either its meta data or, on the ``_DONE_COUNT``-th request, the finish info.
        Returns:
            - result (:obj:`dict`): Data meta dict (keys 'data_id', 'buffer_id', 'unroll_split_begin', \
                'task_id') or finish info dict (key 'collector_done' is True).
        Raises:
            - TaskFail: If no 'collector_start_task' has been processed yet.
        """
        if self.task_info is None:
            raise TaskFail(
                result={'message': 'collector task has not started'},
                message='collector_data_task received before collector_start_task'
            )
        self.count += 1
        task_id = self.task_info['task_id']
        data_id = './{}_{}_{}'.format(self._prefix, task_id, self.count)
        # The payload is written on every request, including the one that reports completion.
        torch.save(self._get_timestep(), data_id)
        if self.count == self._DONE_COUNT:
            return {
                'task_id': task_id,
                'collector_done': True,
                'cur_episode': 1,
                'cur_step': 314,
                'cur_sample': 314,
            }
        return {
            'data_id': data_id,
            'buffer_id': self.task_info['buffer_id'],
            'unroll_split_begin': 0,
            'task_id': task_id,
        }

    def _get_timestep(self) -> list:
        """
        Overview:
            Generate a fake trajectory consisting of a single random timestep.
        Returns:
            - timesteps (:obj:`list`): A one-element list. The element is a dict with keys 'obs' and \
                'next_obs' (float tensors of 4 elements), 'reward' (float tensor of 3 elements, each \
                0. or 1.), 'action' (integer tensor of 1 element, 0 or 1) and 'done' (always False).
        """
        return [
            {
                'obs': torch.rand(4),
                'next_obs': torch.randn(4),
                'reward': torch.randint(0, 2, size=(3, )).float(),
                'action': torch.randint(0, 2, size=(1, )),
                'done': False,
            }
        ]