Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, List | |
| from ding.hpc_rl import hpc_wrapper | |
def shape_fn_scatter_connection(args, kwargs) -> List[int]:
    """
    Overview:
        Build the shape descriptor used by the HPC wrapper for a ``ScatterConnection`` call.
    Arguments:
        - args (:obj:`Tuple`): Positional arguments of the wrapped call. ``args[0]`` is the \
            ``ScatterConnection`` instance; ``args[1]`` (``x``) and ``args[2]`` (``spatial_size``) are \
            optional and are looked up in ``kwargs`` when absent.
        - kwargs (:obj:`Dict`): Keyword arguments of the wrapped call, consulted for ``'x'`` and \
            ``'spatial_size'`` when they were not passed positionally.
    Returns:
        - shape (:obj:`List[int]`): ``[B, M, N, H, W, scatter_type]``, i.e. the dimensions of ``x``, then the \
            entries of ``spatial_size``, then the instance's ``scatter_type`` string as the final element.
    """
    # Fewer than 2 positional args means only ``self`` was passed positionally.
    feature = kwargs['x'] if len(args) < 2 else args[1]
    # Fewer than 3 positional args means ``spatial_size`` came in as a keyword.
    map_size = kwargs['spatial_size'] if len(args) < 3 else args[2]
    # The list display evaluates left to right: x's dims, then the map size, then scatter_type.
    return [*feature.shape, *map_size, args[0].scatter_type]
class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor,
        and these tensors are scattered into a feature map with map size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when several entities share \
                the same location. It must be either 'add' or 'cover'. With 'add', the features of all entities at \
                that location are summed (``Tensor.scatter_add_``). With 'cover', only one entity's features are \
                kept (``Tensor.scatter_``). PyTorch does not guarantee which entity wins when indices repeat.
        Raises:
            - AssertionError: If ``scatter_type`` is neither 'cover' nor 'add'.
        """
        super(ScatterConnection, self).__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add'], f"invalid scatter_type: {scatter_type}"

    def _scatter(self, x: torch.Tensor, spatial_size: Tuple[int, int], index: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Shared implementation: scatter ``x`` of shape ``(B, M, N)`` into a ``(B, N, H, W)`` map using a \
            ``(B, M)`` tensor of row-major flat indices into the ``H * W`` grid.
        """
        B, M, N = x.shape
        H, W = spatial_size
        src = x.permute(0, 2, 1)  # (B, N, M): one row of M entity values per feature channel
        # scatter_/scatter_add_ require int64 indices. Every channel uses the same per-entity target cell.
        index = index.long().unsqueeze(dim=1).repeat(1, N, 1)  # (B, N, M)
        # Match x's dtype: scatter_ requires src and destination dtypes to agree, so a float32 buffer
        # would reject float64/float16 inputs.
        output = torch.zeros(size=(B, N, H * W), dtype=x.dtype, device=x.device)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=src)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=src)
        return output.view(B, N, H, W)

    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
                Each location should be (y, x), with ``0 <= y < H`` and ``0 <= x < W``. Values are cast to int64.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as ``x``.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        H, W = spatial_size
        # Row-major flattening of the (H, W) grid: (y, x) -> y * W + x.
        index = location[:, :, 0] * W + location[:, :, 1]
        return self._scatter(x, spatial_size, index)

    def xy_forward(
        self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using separate coordinate tensors.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): Coordinates of shape `(B, M)` along the **H (row) axis**, \
                ``0 <= coord_x < H``.
            - coord_y (:obj:`torch.Tensor`): Coordinates of shape `(B, M)` along the **W (column) axis**, \
                ``0 <= coord_y < W``.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as ``x``.
        Note:
            The flat index is ``coord_x * W + coord_y``, so ``coord_x`` selects the row and ``coord_y`` the column. \
            This is the opposite of the usual image (x = column) convention.
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        H, W = spatial_size
        index = coord_x * W + coord_y
        return self._scatter(x, spatial_size, index)