File size: 8,718 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmengine.utils import to_3tuple
class AdaptivePadding(nn.Module):
    """Applies padding adaptively to a 5D (video) input.

    This module pads the input so that it gets fully covered by the 3D
    filter you specified. It supports two modes, "same" and "corner". The
    "same" mode is similar to the "SAME" padding mode in TensorFlow and pads
    zeros around the input; when the total padding along a dimension is odd,
    the extra element goes to the back/bottom/right. The "corner" mode pads
    zeros only at the back/bottom/right.

    Args:
        kernel_size (int | tuple): Size of the kernel, arranged as (D, H, W)
            or a single int applied to all three dims. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to back/bottom/right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = (2, 16, 16)
        >>> stride = (2, 16, 16)
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 3, 15, 17)
        >>> adap_pad = AdaptivePadding(
        ...     kernel_size=kernel_size,
        ...     stride=stride,
        ...     dilation=dilation,
        ...     padding="corner")
        >>> out = adap_pad(input)
        >>> assert tuple(out.shape[2:]) == (4, 16, 32)
        >>> input = torch.rand(1, 1, 4, 16, 17)
        >>> out = adap_pad(input)
        >>> assert tuple(out.shape[2:]) == (4, 16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super().__init__()
        assert padding in ('same', 'corner')

        kernel_size = to_3tuple(kernel_size)
        stride = to_3tuple(stride)
        dilation = to_3tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        """Calculate the padding size of input.

        Args:
            input_shape (tuple[int] | :obj:`torch.Size`): Exactly three
                sizes, arranged as (D, H, W).

        Returns:
            Tuple[int]: The padding size (pad_d, pad_h, pad_w) along the
            original D, H and W directions. Each value is >= 0.
        """
        input_d, input_h, input_w = input_shape
        kernel_d, kernel_h, kernel_w = self.kernel_size
        stride_d, stride_h, stride_w = self.stride

        # Number of filter positions needed so that every input element is
        # covered ("ceil mode").
        output_d = math.ceil(input_d / stride_d)
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)

        # Extent spanned by that many (dilated) filter positions, minus
        # the existing input size; never negative.
        pad_d = max((output_d - 1) * stride_d +
                    (kernel_d - 1) * self.dilation[0] + 1 - input_d, 0)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[1] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[2] + 1 - input_w, 0)
        return pad_d, pad_h, pad_w

    def forward(self, x):
        """Add padding to ``x``.

        Args:
            x (Tensor): Input tensor of shape (B, C, D, H, W).

        Returns:
            Tensor: The tensor with adaptive padding. It is returned
            unchanged when no padding is needed.
        """
        # The last three dims are (D, H, W); get_pad_shape unpacks exactly
        # three sizes, so slicing only two would fail.
        pad_d, pad_h, pad_w = self.get_pad_shape(x.size()[-3:])
        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                # F.pad takes pads starting from the last dim:
                # (W_left, W_right, H_top, H_bottom, D_front, D_back).
                x = F.pad(x, [0, pad_w, 0, pad_h, 0, pad_d])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2,
                    pad_w - pad_w // 2,
                    pad_h // 2,
                    pad_h - pad_h // 2,
                    pad_d // 2,
                    pad_d - pad_d // 2,
                ])
        return x
class PatchEmbed3D(BaseModule):
    """Video to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_type (str): The type of convolution
            to generate patch embedding. Default: "Conv3d".
        kernel_size (int | tuple): The kernel_size of embedding conv.
            Default: (2, 4, 4).
        stride (int | tuple, optional): The slide stride of embedding conv.
            If None, it is set to ``kernel_size``. Default: (2, 4, 4).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple): The dilation rate of embedding conv.
            Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, arranged as
            (T, H, W). When given, it is used to precompute
            ``init_out_size``. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type='Conv3d',
                 kernel_size=(2, 4, 4),
                 stride=(2, 4, 4),
                 padding='corner',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 input_size=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_3tuple(kernel_size)
        stride = to_3tuple(stride)
        dilation = to_3tuple(dilation)

        if isinstance(padding, str):
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv; adaptive padding replaces it
            padding = 0
        else:
            self.adaptive_padding = None
        # Normalize in both branches: the output-size formula below indexes
        # padding per dimension, which fails on a bare int.
        padding = to_3tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_3tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # e.g. when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adaptive_padding is not None:
                pad_t, pad_h, pad_w = self.adaptive_padding.get_pad_shape(
                    input_size)
                input_t, input_h, input_w = input_size
                input_size = (input_t + pad_t, input_h + pad_h,
                              input_w + pad_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
            self.init_out_size = tuple(
                (size + 2 * pad - dil * (k - 1) - 1) // s + 1
                for size, pad, dil, k, s in zip(input_size, padding,
                                                dilation, kernel_size,
                                                stride))
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, T, H, W). In most case, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_t * out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (out_t, out_h, out_w).
        """
        if self.adaptive_padding is not None:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3], x.shape[4])
        # Merge the three trailing dims, then move channels last:
        # (B, C, T, H, W) -> (B, T*H*W, C).
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
|