# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer
from mmengine.model import BaseModule

from mmpose.registry import MODELS
from ..utils.regularizations import WeightNormClipHook
from .base_backbone import BaseBackbone


class BasicTemporalBlock(BaseModule):
    """Basic block for VideoPose3D.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use an optimized TCN designed specifically
            for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
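
    Example:
        >>> # An illustrative sketch (not part of the upstream docstring):
        >>> # with the default kernel_size=3 and dilation=3, the block
        >>> # shrinks the sequence by (kernel_size - 1) * dilation = 6.
        >>> from mmpose.models.backbones.tcn import BasicTemporalBlock
        >>> import torch
        >>> self = BasicTemporalBlock(in_channels=1024, out_channels=1024)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1024, 243)
        >>> out = self.forward(inputs)
        >>> print(tuple(out.shape))
        (1, 1024, 237)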
| """ | |

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 init_cfg=None):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            # In single-frame batching mode, the dilated convolution is
            # replaced by a strided one, so the effective dilation is 1 and
            # a causal block simply shifts the sampling window.
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0

        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        if self.use_stride_conv:
            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
        else:
            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
            # Crop the input so that the residual matches the temporally
            # shrunken (or strided) output of the convolutions.
            if self.use_stride_conv:
                res = x[:, :, self.causal_shift +
                        self.kernel_size // 2::self.kernel_size]
            else:
                res = x[:, :,
                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
                                                        self.causal_shift)]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out


@MODELS.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            each basic block. Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications).
            Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): if not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default:
            ``[
                dict(
                    type='Kaiming',
                    mode='fan_in',
                    nonlinearity='relu',
                    layer=['Conv2d']),
                dict(
                    type='Constant',
                    val=1,
                    layer=['_BatchNorm', 'GroupNorm'])
            ]``

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
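
        >>> # A hedged second sketch (not in the upstream docstring): in
        >>> # single-frame batching mode the input length should equal the
        >>> # receptive field, 3 * 3 * 3 = 27 for the default kernel sizes,
        >>> # and the last block then emits a single output frame.
        >>> self = TCN(in_channels=34, use_stride_conv=True)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 27)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 3)
        (1, 1024, 1)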
| """ | |

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None,
                 init_cfg=[
                     dict(
                         type='Kaiming',
                         mode='fan_in',
                         nonlinearity='relu',
                         layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        # The first kernel size is consumed by expand_conv, so there must be
        # exactly one kernel size per block plus one for the stem.
        assert num_blocks == len(kernel_sizes) - 1
        for ks in kernel_sizes:
            assert ks % 2 == 1, 'Only odd filter widths are supported.'

        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=kernel_sizes[0],
            stride=kernel_sizes[0] if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        dilation = kernel_sizes[0]
        self.tcn_blocks = nn.ModuleList()
        for i in range(1, num_blocks + 1):
            self.tcn_blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            # The dilation grows multiplicatively, so the receptive field is
            # the product of the kernel sizes.
            dilation *= kernel_sizes[i]

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        x = self.expand_conv(x)

        if self.dropout is not None:
            x = self.dropout(x)

        outs = []
        for i in range(self.num_blocks):
            x = self.tcn_blocks[i](x)
            outs.append(x)

        return tuple(outs)