import copy
import math
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from detect_tools.upn import POS_EMBEDDINGS
from detect_tools.upn.models.module import NestedTensor

@POS_EMBEDDINGS.register_module()
class PositionEmbeddingSine(nn.Module):
"""This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
Args:
num_pos_feats (int): The channel of positional embeddings.
temperature (float): The temperature used in positional embeddings.
normalize (bool): Whether to normalize the positional embeddings.
scale (float): The scale factor of positional embeddings.
"""
def __init__(
self,
num_pos_feats: int = 64,
temperature: int = 10000,
normalize: bool = False,
        scale: Optional[float] = None,
) -> None:
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
"""Forward function.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
"""
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (
            2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
        )
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
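
# Illustrative usage (a sketch, assuming NestedTensor(tensors, mask) wraps a
# padded batch with a boolean mask where True marks padded pixels, as in DETR):
#
#   feats = torch.randn(2, 256, 32, 32)
#   mask = torch.zeros(2, 32, 32, dtype=torch.bool)    # no padding
#   pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#   pos = pe(NestedTensor(feats, mask))                # (2, 256, 32, 32)
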
@POS_EMBEDDINGS.register_module()
class PositionEmbeddingSineHW(nn.Module):
"""This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
Args:
num_pos_feats (int): The channel of positional embeddings.
temperatureH (float): The temperature used in positional embeddings.
temperatureW (float): The temperature used in positional embeddings.
normalize (bool): Whether to normalize the positional embeddings.
scale (float): The scale factor of positional embeddings.
"""
def __init__(
self,
num_pos_feats: int = 64,
temperatureH: int = 10000,
temperatureW: int = 10000,
normalize: bool = False,
        scale: Optional[float] = None,
) -> None:
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
self.temperatureW = temperatureW
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
"""Forward function.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
"""
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (
            2 * torch.div(dim_tx, 2, rounding_mode="floor") / self.num_pos_feats
        )
        pos_x = x_embed[:, :, :, None] / dim_tx
        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (
            2 * torch.div(dim_ty, 2, rounding_mode="floor") / self.num_pos_feats
        )
        pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
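
# Illustrative usage (a sketch, same NestedTensor assumption as above); the
# separate temperatures let the y and x axes use different frequency bands:
#
#   pe = PositionEmbeddingSineHW(
#       num_pos_feats=128, temperatureH=20, temperatureW=20, normalize=True
#   )
#   pos = pe(NestedTensor(feats, mask))                # (2, 256, 32, 32)
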
@POS_EMBEDDINGS.register_module()
class PositionEmbeddingLearned(nn.Module):
"""Absolute pos embedding, learned.
Args:
num_pos_feats (int): The channel dimension of positional embeddings.
num_row (int): The number of rows of the input feature map.
num_col (int): The number of columns of the input feature map.
"""
def __init__(
self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256
) -> None:
super().__init__()
self.row_embed = nn.Embedding(num_row, num_pos_feats)
self.col_embed = nn.Embedding(num_col, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
"""Forward function.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
"""
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
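
# Illustrative usage (a sketch): the learned tables cover feature maps up to
# num_row x num_col; a larger input would index the embeddings out of range.
#
#   pe = PositionEmbeddingLearned(num_row=50, num_col=50, num_pos_feats=128)
#   pos = pe(NestedTensor(feats, mask))                # (2, 256, 32, 32)
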
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSineHW(
N_steps,
temperatureH=args.pe_temperatureH,
temperatureW=args.pe_temperatureW,
normalize=True,
)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"Unsupported position embedding: {args.position_embedding}")
return position_embedding
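
# Illustrative usage (a sketch; any object exposing the fields read above
# works as `args`):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(hidden_dim=256, position_embedding="sine",
#                          pe_temperatureH=20, pe_temperatureW=20)
#   pos_embed = build_position_encoding(args)   # -> PositionEmbeddingSineHW
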
def clean_state_dict(state_dict):
    """Strip the ``module.`` prefix (added by DataParallel/DDP) from keys."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k[len("module.") :]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict
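
# Illustrative usage (a sketch; the checkpoint path and "model" key are
# hypothetical):
#
#   model = ...  # some nn.Module matching the checkpoint
#   ckpt = torch.load("checkpoint.pth", map_location="cpu")
#   model.load_state_dict(clean_state_dict(ckpt["model"]))
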
def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0):
"""Return an activation function given a string
Args:
activation (str): activation function name
d_model (int, optional): d_model. Defaults to 256.
batch_dim (int, optional): batch dimension. Defaults to 0.
Returns:
F: activation function
"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
    raise RuntimeError(
        f"activation should be relu/gelu/glu/prelu/selu, not {activation}."
    )
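
# Illustrative usage (a sketch): note that "prelu" returns an nn.Module with a
# learnable weight, while the other names map to stateless functions from F.
#
#   act = get_activation_fn("gelu")
#   y = act(torch.randn(4, 256))
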
def get_clones(module: nn.Module, N: int, layer_share: bool = False):
"""Copy module N times
Args:
module (nn.Module): module to copy
N (int): number of copies
layer_share (bool, optional): share the same layer. If true, the modules will
share the same memory. Defaults to False.
"""
if layer_share:
return nn.ModuleList([module for _ in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
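
# Illustrative usage (a sketch): with layer_share=False every clone has its
# own parameters; with layer_share=True all N entries are the very same object.
#
#   layers = get_clones(nn.Linear(256, 256), N=6)
#   assert layers[0] is not layers[1]          # independent deep copies
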
def inverse_sigmoid(x, eps=1e-3):
    """Inverse of the sigmoid (logit) function, clamped for numerical stability."""
    x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
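
# Illustrative check (a sketch): away from the clamp boundaries this inverts
# torch.sigmoid, since log(x / (1 - x)) is the logit function.
#
#   p = torch.tensor([0.1, 0.5, 0.9])
#   assert torch.allclose(torch.sigmoid(inverse_sigmoid(p)), p, atol=1e-4)
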
def gen_sineembed_for_position(pos_tensor):
    """Generate sinusoidal embeddings for box positions.

    Args:
        pos_tensor (torch.Tensor): Normalized coordinates in [0, 1], shaped
            (n_query, bs, 2) for (cx, cy) or (n_query, bs, 4) for (cx, cy, w, h).

    Returns:
        torch.Tensor: Shape (n_query, bs, 256) for 2-D boxes,
            (n_query, bs, 512) for 4-D boxes.
    """
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
).flatten(2)
pos_y = torch.stack(
(pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack(
(pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack(
(pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(f"Unknown pos_tensor shape(-1): {pos_tensor.size(-1)}")
return pos
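
# Illustrative usage (a sketch): each normalized coordinate expands to 128
# sine/cosine features, so (cx, cy) boxes give 256 channels and
# (cx, cy, w, h) boxes give 512.
#
#   boxes = torch.rand(300, 2, 4)               # (n_query, bs, 4) in [0, 1]
#   embed = gen_sineembed_for_position(boxes)   # (300, 2, 512)
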
def get_sine_pos_embed(
pos_tensor: torch.Tensor,
num_pos_feats: int = 128,
temperature: int = 10000,
exchange_xy: bool = True,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (
2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
)
def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t
        sin_x = torch.stack(
            (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=-1
        ).flatten(-2)
return sin_x
pos_res = [
sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
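
# Illustrative usage (a sketch): unlike gen_sineembed_for_position, the feature
# width and temperature are configurable here and any leading dims are allowed.
#
#   coords = torch.rand(2, 300, 4)                          # [..., n]
#   embed = get_sine_pos_embed(coords, num_pos_feats=128)   # (2, 300, 512)
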
def gen_encoder_output_proposals(
memory: torch.Tensor,
memory_padding_mask: torch.Tensor,
spatial_shapes: torch.Tensor,
learnedwh=None,
):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
    N_, S_, C_ = memory.shape
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
N_, H_, W_, 1
)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",
        )
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
N_, 1, 1, 2
)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)
).all(-1, keepdim=True)
    output_proposals = torch.log(
        output_proposals / (1 - output_proposals)
    )  # inverse sigmoid
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), 0.0)
    output_memory = output_memory.masked_fill(~output_proposals_valid, 0.0)
return output_memory, output_proposals
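
# Illustrative usage (a sketch): one anchor proposal per encoder token across
# two feature levels; proposals come back in inverse-sigmoid space, with
# invalid/padded positions set to inf.
#
#   spatial_shapes = torch.tensor([[32, 32], [16, 16]])
#   S = int(spatial_shapes.prod(-1).sum())                  # 1280 tokens
#   memory = torch.randn(2, S, 256)
#   padding_mask = torch.zeros(2, S, dtype=torch.bool)
#   out_memory, out_proposals = gen_encoder_output_proposals(
#       memory, padding_mask, spatial_shapes
#   )                                                       # (2, S, 256), (2, S, 4)
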