Spaces:
Running
Running
| import copy | |
| from typing import List, Tuple | |
| import numpy as np | |
| from easydict import EasyDict | |
| from ding.utils.compression_helper import jpeg_data_decompressor | |
class GameSegment:
    """
    Overview:
        A game segment from a full episode trajectory.
        The length of one episode in (Atari) games is often quite large. This class represents a single game segment
        within a larger trajectory, split into several blocks.
    Interfaces:
        - __init__
        - __len__
        - reset
        - pad_over
        - is_full
        - legal_actions
        - append
        - get_unroll_obs
        - get_obs
        - zero_obs
        - get_targets
        - game_segment_to_array
        - store_search_stats
    """

    def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
        """
        Overview:
            Init the ``GameSegment`` according to the provided arguments.
        Arguments:
            - action_space (:obj:`int`): the action space; either the number of actions or an object exposing \
                the number of actions as ``n`` (see ``legal_actions``).
            - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block.
            - config (:obj:`EasyDict`): the configuration. Although it defaults to ``None``, it is required: \
                ``num_unroll_steps``, ``td_steps``, ``discount_factor``, ``gray_scale``, ``transform2string``, \
                ``sampled_algo``, ``gumbel_algo``, ``use_ture_chance_label_in_chance_encoder`` and the ``model`` \
                sub-config (``frame_stack_num``, ``action_space_size``, ``observation_shape`` and, for image \
                observations, ``image_channel``) are read from it.
        """
        self.action_space = action_space
        self.game_segment_length = game_segment_length

        self.num_unroll_steps = config.num_unroll_steps
        self.td_steps = config.td_steps
        self.frame_stack_num = config.model.frame_stack_num
        self.discount_factor = config.discount_factor
        self.action_space_size = config.model.action_space_size
        self.gray_scale = config.gray_scale
        self.transform2string = config.transform2string
        self.sampled_algo = config.sampled_algo
        self.gumbel_algo = config.gumbel_algo
        self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

        observation_shape = config.model.observation_shape
        if isinstance(observation_shape, int) or len(observation_shape) == 1:
            # for vector obs input, e.g. classical control and box2d environments
            self.zero_obs_shape = observation_shape
        elif len(observation_shape) == 3:
            # image obs input, e.g. atari environments: last two dims of the shape followed by the channel number
            self.zero_obs_shape = (observation_shape[-2], observation_shape[-1], config.model.image_channel)
        # NOTE: for any other observation shape ``zero_obs_shape`` is left unset, so ``zero_obs`` is unavailable.

        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []

        self.target_values = []
        self.target_rewards = []
        self.target_policies = []

        self.improved_policy_probs = []

        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

    def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
        """
        Overview:
            Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps].
        Arguments:
            - timestep (:obj:`int`): The time step.
            - num_unroll_steps (:obj:`int`): The extra length of the observation frames.
            - padding (:obj:`bool`): If True, pad frames if (t + stack frames) is outside of the trajectory. \
                The last available frame is repeated to fill the missing positions.
        Returns:
            - stacked_obs: The selected frames. This is a slice of ``obs_segment`` when no padding is applied, \
                a ``numpy.ndarray`` when frames were padded, and a list of decompressed frames when \
                ``transform2string`` is enabled.
        """
        window = self.frame_stack_num + num_unroll_steps
        stacked_obs = self.obs_segment[timestep:timestep + window]
        if padding:
            pad_len = window - len(stacked_obs)
            if pad_len > 0:
                pad_frames = np.array([stacked_obs[-1]] * pad_len)
                stacked_obs = np.concatenate((stacked_obs, pad_frames))
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def zero_obs(self) -> List:
        """
        Overview:
            Return a stack of observation frames filled with zeros.
        Returns:
            - zero_frames (:obj:`List`): ``frame_stack_num`` float32 arrays of shape ``zero_obs_shape`` \
                filled with zeros.
        """
        return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

    def get_obs(self) -> List:
        """
        Overview:
            Return an observation in the correct format for model inference, i.e. the latest
            ``frame_stack_num`` frames of ``obs_segment``.
        Returns:
            - stacked_obs (:obj:`List`): An observation in the correct format for model inference.
        """
        timestep_obs = len(self.obs_segment) - self.frame_stack_num
        timestep_reward = len(self.reward_segment)
        # The observation list always leads the reward list by exactly ``frame_stack_num`` frames.
        assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
            timestep_obs, timestep_reward
        )
        timestep = timestep_reward
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def append(
            self,
            action: np.ndarray,
            obs: np.ndarray,
            reward: np.ndarray,
            action_mask: np.ndarray = None,
            to_play: int = -1,
            chance: int = 0,
    ) -> None:
        """
        Overview:
            Append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}.
            ``chance`` is stored as well when ``use_ture_chance_label_in_chance_encoder`` is enabled.
        """
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)

        self.action_mask_segment.append(action_mask)
        self.to_play_segment.append(to_play)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.append(chance)

    def pad_over(
            self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
            next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
    ) -> None:
        """
        Overview:
            To ensure the correctness of value targets, we need to add (o_t, r_t, etc) from the next game_segment,
            which is necessary for the bootstrapped values at the end states of previous game_segment.
            e.g: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... are from the next game_segment.
        Arguments:
            - next_segment_observations (:obj:`list`): o_t from the next game_segment
            - next_segment_rewards (:obj:`list`): r_t from the next game_segment
            - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment
            - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next \
                game_segment
            - next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next \
                game_segment (Only used in Gumbel MuZero)
            - next_chances (:obj:`list`): chance labels from the next game_segment (only used when \
                ``use_ture_chance_label_in_chance_encoder`` is enabled)
        """
        assert len(next_segment_observations) <= self.num_unroll_steps
        assert len(next_segment_child_visits) <= self.num_unroll_steps
        assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
        assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        if self.gumbel_algo:
            assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

        # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
        # Observations are deep-copied so this segment does not share frames with the next one.
        self.obs_segment.extend(copy.deepcopy(observation) for observation in next_segment_observations)
        self.reward_segment.extend(next_segment_rewards)
        self.root_value_segment.extend(next_segment_root_values)
        self.child_visit_segment.extend(next_segment_child_visits)

        if self.gumbel_algo:
            self.improved_policy_probs.extend(next_segment_improved_policy)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.extend(next_chances)

    def get_targets(self, timestep: int) -> Tuple:
        """
        Overview:
            Return the value/reward/policy targets at step ``timestep``.
        """
        return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]

    def store_search_stats(
            self, visit_counts: List, root_value: List, root_sampled_actions: List = None,
            improved_policy: List = None, idx: int = None
    ) -> None:
        """
        Overview:
            Store the visit count distributions and value of the root node after MCTS.
        Arguments:
            - visit_counts (:obj:`List`): visit counts of the root children; they are normalized by their sum \
                before being stored.
            - root_value (:obj:`List`): value of the root node.
            - root_sampled_actions (:obj:`List`): sampled actions of the root node; only stored (on append) when \
                ``sampled_algo`` is enabled.
            - improved_policy (:obj:`List`): improved policy of the root node; only stored when ``gumbel_algo`` \
                is enabled.
            - idx (:obj:`int`): if ``None``, the statistics are appended as a new step; otherwise the statistics \
                already stored at position ``idx`` are overwritten.
        """
        sum_visits = sum(visit_counts)
        visit_distribution = [visit_count / sum_visits for visit_count in visit_counts]
        if idx is None:
            self.child_visit_segment.append(visit_distribution)
            self.root_value_segment.append(root_value)
            if self.sampled_algo:
                self.root_sampled_actions.append(root_sampled_actions)
            # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
            if self.gumbel_algo:
                self.improved_policy_probs.append(improved_policy)
        else:
            self.child_visit_segment[idx] = visit_distribution
            self.root_value_segment[idx] = root_value
            # ``improved_policy_probs`` is only populated for Gumbel MuZero; writing to it unconditionally
            # raises IndexError for every other algorithm, so mirror the guard of the append branch.
            if self.gumbel_algo:
                self.improved_policy_probs[idx] = improved_policy

    def game_segment_to_array(self) -> None:
        """
        Overview:
            Post-process the data when a `GameSegment` block is full. This function converts various game segment
            elements into numpy arrays for easier manipulation and processing.
        Structure:
            The structure and shapes of different game segment elements are as follows. Let's assume
            `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:

            - obs:            game_segment_length + stack + num_unroll_steps, 20+4+5
            - action:         game_segment_length -> 20
            - reward:         game_segment_length + num_unroll_steps + td_steps -1  20+5+5-1
            - root_values:    game_segment_length + num_unroll_steps + td_steps -> 20+5+5
            - child_visits:   game_segment_length + num_unroll_steps -> 20+5
            - to_play:        game_segment_length -> 20
            - action_mask:    game_segment_length -> 20
        Examples:
            Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
            (game_segment_i and game_segment_i+1):

            - game_segment_i (obs):     4       20        5
                                      ----|----...----|-----|
            - game_segment_i+1 (obs):               4       20        5
                                                  ----|----...----|-----|

            - game_segment_i (rew):        20        5      4
                                      ----...----|------|-----|
            - game_segment_i+1 (rew):                20        5    4
                                                ----...----|------|-----|

        Postprocessing:
            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original \
                child_visit_segment.
            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original \
                root_value_segment.
            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original \
                improved_policy_probs.
            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original \
                action_mask_segment.
            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original \
                chance_segment. Only created if `self.use_ture_chance_label_in_chance_encoder` is True.

        .. note::
            For environments with a variable action space, such as board games, the elements in \
            `child_visit_segment` may have different lengths. In such scenarios, it is necessary to use the \
            object data type for `self.child_visit_segment`.
        """
        self.obs_segment = np.array(self.obs_segment)
        self.action_segment = np.array(self.action_segment)
        self.reward_segment = np.array(self.reward_segment)

        # Check if all elements in self.child_visit_segment have the same length.
        # (The generator is never evaluated for an empty segment, so indexing [0] is safe here.)
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # In the case of environments with a variable action space, such as board games,
            # the elements in child_visit_segment may have different lengths.
            # In such scenarios, it is necessary to use the object data type.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)

        self.root_value_segment = np.array(self.root_value_segment)
        self.improved_policy_probs = np.array(self.improved_policy_probs)

        self.action_mask_segment = np.array(self.action_mask_segment)
        self.to_play_segment = np.array(self.to_play_segment)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = np.array(self.chance_segment)

    def reset(self, init_observations: np.ndarray) -> None:
        """
        Overview:
            Initialize the game segment using ``init_observations``,
            which is the previous ``frame_stack_num`` stacked frames.
            All per-step lists filled by ``append`` and ``store_search_stats`` are emptied first.
        Arguments:
            - init_observations (:obj:`list`): list of the stack observations in the previous time steps. \
                Its length must equal ``frame_stack_num``.
        """
        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []
        # Search statistics stored next to the child visits must be cleared as well, otherwise a reused
        # segment would keep stale entries that no longer line up with the new steps.
        self.improved_policy_probs = []
        if self.sampled_algo:
            self.root_sampled_actions = []

        self.action_mask_segment = []
        self.to_play_segment = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

        assert len(init_observations) == self.frame_stack_num

        # Deep-copy so later changes to the caller's frames do not leak into this segment.
        self.obs_segment.extend(copy.deepcopy(observation) for observation in init_observations)

    def is_full(self) -> bool:
        """
        Overview:
            Check whether the current game segment is full, i.e. the number of stored actions has reached
            ``game_segment_length``.
        Returns:
            - full (:obj:`bool`): True if the game segment is full, False otherwise.
        """
        return len(self.action_segment) >= self.game_segment_length

    def legal_actions(self) -> List:
        """
        Overview:
            Return the list of all action indices ``[0, ..., n - 1]``.
            ``n`` is ``action_space.n`` when the action space exposes it, otherwise ``action_space`` itself
            is used as the number of actions (matching the ``int`` annotation of ``__init__``).
        Returns:
            - actions (:obj:`List`): all action indices of the action space.
        """
        num_actions = getattr(self.action_space, "n", self.action_space)
        return list(range(num_actions))

    def __len__(self):
        """
        Overview:
            Return the number of actions stored in this segment.
        """
        return len(self.action_segment)