import copy
import math
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from detect_tools.upn import POS_EMBEDDINGS
from detect_tools.upn.models.module import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very similar to
    the one used in the "Attention Is All You Need" paper, generalized to work on
    images.

    Args:
        num_pos_feats (int): The channel dimension of the positional embeddings.
        temperature (int): The temperature used in the positional embeddings.
        normalize (bool): Whether to normalize the positional embeddings.
        scale (float, optional): The scale factor of the positional embeddings.
            Requires ``normalize=True``; defaults to ``2 * math.pi``.
    """

    def __init__(
        self,
        num_pos_feats: int = 64,
        temperature: int = 10000,
        normalize: bool = False,
        scale: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Positional encoding of shape (B, num_pos_feats*2, H, W).
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        # Cumulative sums over valid (unmasked) pixels give per-pixel coordinates.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine (even channels) and cosine (odd channels).
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
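
# Usage sketch (illustrative; assumes NestedTensor(tensors, mask) as in DETR,
# where the mask marks padded pixels with True):
#
#   >>> pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#   >>> feats = torch.zeros(2, 256, 32, 32)
#   >>> mask = torch.zeros(2, 32, 32, dtype=torch.bool)   # no padding
#   >>> pe(NestedTensor(feats, mask)).shape
#   torch.Size([2, 256, 32, 32])
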
class PositionEmbeddingSineHW(nn.Module):
    """Sine position embedding as in `PositionEmbeddingSine`, but with separate
    temperatures for the height and width axes.

    Args:
        num_pos_feats (int): The channel dimension of the positional embeddings.
        temperatureH (int): The temperature used along the height axis.
        temperatureW (int): The temperature used along the width axis.
        normalize (bool): Whether to normalize the positional embeddings.
        scale (float, optional): The scale factor of the positional embeddings.
            Requires ``normalize=True``; defaults to ``2 * math.pi``.
    """

    def __init__(
        self,
        num_pos_feats: int = 64,
        temperatureH: int = 10000,
        temperatureW: int = 10000,
        normalize: bool = False,
        scale: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Positional encoding of shape (B, num_pos_feats*2, H, W).
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx
        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
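
# Called exactly like PositionEmbeddingSine; only the per-axis temperatures
# differ (illustrative):
#
#   >>> pe = PositionEmbeddingSineHW(128, temperatureH=20, temperatureW=20, normalize=True)
#   >>> pe(NestedTensor(torch.zeros(1, 256, 16, 24),
#   ...                 torch.zeros(1, 16, 24, dtype=torch.bool))).shape
#   torch.Size([1, 256, 16, 24])
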
class PositionEmbeddingLearned(nn.Module):
    """Absolute position embedding, learned.

    Args:
        num_row (int): The number of rows of the input feature map.
        num_col (int): The number of columns of the input feature map.
        num_pos_feats (int): The channel dimension of the positional embeddings.
    """

    def __init__(
        self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256
    ) -> None:
        super().__init__()
        self.row_embed = nn.Embedding(num_row, num_pos_feats)
        self.col_embed = nn.Embedding(num_col, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
                The feature map must satisfy H <= num_row and W <= num_col.

        Returns:
            torch.Tensor: Positional encoding of shape (B, num_pos_feats*2, H, W).
        """
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)  # (W, num_pos_feats)
        y_emb = self.row_embed(j)  # (H, num_pos_feats)
        pos = (
            torch.cat(
                [
                    x_emb.unsqueeze(0).repeat(h, 1, 1),
                    y_emb.unsqueeze(1).repeat(1, w, 1),
                ],
                dim=-1,
            )
            .permute(2, 0, 1)
            .unsqueeze(0)
            .repeat(x.shape[0], 1, 1, 1)
        )
        return pos
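
# Usage sketch (illustrative; this variant ignores the mask, and assumes
# NestedTensor accepts mask=None as in DETR):
#
#   >>> pe = PositionEmbeddingLearned(num_pos_feats=128)
#   >>> pe(NestedTensor(torch.zeros(2, 256, 20, 30), None)).shape
#   torch.Size([2, 256, 20, 30])
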
def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ("v2", "sine"):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSineHW(
            N_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True,
        )
    elif args.position_embedding in ("v3", "learned"):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"Unsupported position embedding: {args.position_embedding}")
    return position_embedding
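
# Minimal sketch of the `args` object this expects; the attribute names
# (hidden_dim, position_embedding, pe_temperatureH/W) come from the body above:
#
#   >>> from types import SimpleNamespace
#   >>> args = SimpleNamespace(hidden_dim=256, position_embedding="sine",
#   ...                        pe_temperatureH=20, pe_temperatureW=20)
#   >>> isinstance(build_position_encoding(args), PositionEmbeddingSineHW)
#   True
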
def clean_state_dict(state_dict):
    """Strip the `module.` prefix that nn.DataParallel/DDP adds to parameter names."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict
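
# Illustrative use when loading a checkpoint saved from an nn.DataParallel model:
#
#   >>> sd = {"module.row_embed.weight": torch.zeros(50, 256)}
#   >>> list(clean_state_dict(sd))
#   ['row_embed.weight']
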
def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0):
    """Return an activation function given a string.

    Args:
        activation (str): Activation function name, one of
            "relu", "gelu", "glu", "prelu", or "selu".
        d_model (int, optional): d_model. Defaults to 256. (Currently unused.)
        batch_dim (int, optional): Batch dimension. Defaults to 0. (Currently unused.)

    Returns:
        The activation: a functional for most names, but an `nn.PReLU` module
        for "prelu", since it carries a learnable parameter.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        return nn.PReLU()
    if activation == "selu":
        return F.selu
    raise RuntimeError(
        f"activation should be relu/gelu/glu/prelu/selu, not {activation}."
    )
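
# Quick check of the functional lookup (illustrative):
#
#   >>> act = get_activation_fn("relu")
#   >>> act(torch.tensor([-1.0, 2.0]))
#   tensor([0., 2.])
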
def get_clones(module: nn.Module, N: int, layer_share: bool = False):
    """Copy a module N times.

    Args:
        module (nn.Module): Module to copy.
        N (int): Number of copies.
        layer_share (bool, optional): If True, every entry of the returned list is
            the same object, so the layers share parameters. Defaults to False.
    """
    if layer_share:
        return nn.ModuleList([module for _ in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
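
# Deep copies vs. a shared layer (illustrative):
#
#   >>> layers = get_clones(nn.Linear(4, 4), 3)
#   >>> layers[0] is layers[1]          # independent copies
#   False
#   >>> shared = get_clones(nn.Linear(4, 4), 3, layer_share=True)
#   >>> shared[0] is shared[1]          # one object, shared weights
#   True
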
def inverse_sigmoid(x, eps=1e-3):
    """Numerically stable logit: inverts `torch.sigmoid`, clamping inputs to
    [eps, 1 - eps] to avoid infinities at 0 and 1."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
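
# Round-trip sketch: exact away from the clamped boundaries.
#
#   >>> v = torch.tensor([0.0, 1.5, -2.0])
#   >>> torch.allclose(inverse_sigmoid(torch.sigmoid(v), eps=1e-6), v, atol=1e-4)
#   True
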
def gen_sineembed_for_position(pos_tensor):
    """Sine embedding for box coordinates.

    Args:
        pos_tensor: (n_query, bs, 2) for (x, y), or (n_query, bs, 4) for
            (x, y, w, h), with values normalized to [0, 1].

    Returns:
        torch.Tensor: (n_query, bs, 256) or (n_query, bs, 512) embedding,
        128 dims per coordinate.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(f"Unknown pos_tensor shape(-1): {pos_tensor.size(-1)}")
    return pos
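
# Shape sketch: 128 sine/cosine dims per box coordinate (hypothetical sizes).
#
#   >>> boxes = torch.rand(300, 2, 4)   # (n_query, bs, 4): x, y, w, h in [0, 1]
#   >>> gen_sineembed_for_position(boxes).shape
#   torch.Size([300, 2, 512])
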
def get_sine_pos_embed(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    exchange_xy: bool = True,
):
    """Generate a sine position embedding from a position tensor.

    Args:
        pos_tensor (torch.Tensor): Shape [..., n].
        num_pos_feats (int): Projected dimension for each float in the tensor.
        temperature (int): Temperature in the sine/cosine function.
        exchange_xy (bool, optional): Exchange pos x and pos y. For example, if
            the input tensor is [x, y], the result is [pos(y), pos(x)].
            Defaults to True.

    Returns:
        pos_embed (torch.Tensor): Shape [..., n * num_pos_feats].
    """
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
    )

    def sine_func(x: torch.Tensor):
        sin_x = x * scale / dim_t
        # Stack on the trailing dims so this works for any [..., n] input, as the
        # docstring promises (the previous dim=3/flatten(2) assumed a 3-D input).
        sin_x = torch.stack(
            (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=-1
        ).flatten(-2)
        return sin_x

    pos_res = [
        sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
    ]
    if exchange_xy:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    pos_res = torch.cat(pos_res, dim=-1)
    return pos_res
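
# Shape sketch for a generic coordinate tensor (hypothetical sizes):
#
#   >>> xy = torch.rand(100, 2, 2)      # [..., 2] = (x, y)
#   >>> get_sine_pos_embed(xy, num_pos_feats=128).shape
#   torch.Size([100, 2, 256])
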
def gen_encoder_output_proposals(
    memory: torch.Tensor,
    memory_padding_mask: torch.Tensor,
    spatial_shapes: torch.Tensor,
    learnedwh=None,
):
    r"""
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
        - learnedwh: 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4 (in inverse-sigmoid space)
    """
    N_, S_, C_ = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
            N_, H_, W_, 1
        )
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        # `indexing="ij"` keeps the pre-1.10 default behavior and silences the
        # meshgrid deprecation warning (requires torch >= 1.10).
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
            N_, 1, 1, 2
        )
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
        if learnedwh is not None:
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += H_ * W_
    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = (
        (output_proposals > 0.01) & (output_proposals < 0.99)
    ).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float("inf")
    )
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float("inf")
    )
    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0)
    )
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals
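
# Single-level shape sketch (hypothetical sizes; the returned proposals live in
# inverse-sigmoid space, with invalid entries set to inf):
#
#   >>> bs, H, W, C = 2, 16, 24, 256
#   >>> memory = torch.rand(bs, H * W, C)
#   >>> mask = torch.zeros(bs, H * W, dtype=torch.bool)   # all positions valid
#   >>> shapes = torch.tensor([[H, W]])                   # (nlevel, 2)
#   >>> mem, props = gen_encoder_output_proposals(memory, mask, shapes)
#   >>> mem.shape, props.shape
#   (torch.Size([2, 384, 256]), torch.Size([2, 384, 4]))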