# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential


class unit_gcn(BaseModule):
    """The basic unit of graph convolutional network.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): The adjacency matrix defined in the graph
            with shape of `(num_subsets, num_nodes, num_nodes)`.
        adaptive (str, optional): The strategy for adapting the weights of
            the adjacency matrix. One of ``None``, ``'init'``,
            ``'offset'`` or ``'importance'``. Defaults to
            ``'importance'``.
        conv_pos (str): The position of the 1x1 2D conv, ``'pre'`` or
            ``'post'``. Defaults to ``'pre'``.
        with_res (bool): Whether to use residual connection.
            Defaults to False.
        norm (str or dict): The name (or config dict) of the norm layer.
            Defaults to ``'BN'``.
        act (str or dict): The name (or config dict) of the activation
            layer. Defaults to ``'ReLU'``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 A: torch.Tensor,
                 adaptive: Optional[str] = 'importance',
                 conv_pos: str = 'pre',
                 with_res: bool = False,
                 norm: Union[str, Dict] = 'BN',
                 act: Union[str, Dict] = 'ReLU',
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_subsets = A.size(0)

        assert adaptive in [None, 'init', 'offset', 'importance']
        self.adaptive = adaptive
        assert conv_pos in ['pre', 'post']
        self.conv_pos = conv_pos
        self.with_res = with_res

        self.norm_cfg = norm if isinstance(norm, dict) else dict(type=norm)
        self.act_cfg = act if isinstance(act, dict) else dict(type=act)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

        # 'init': the adjacency itself is learned. Otherwise it is a fixed
        # buffer, optionally combined with the learned ``PA`` in forward.
        if self.adaptive == 'init':
            self.A = nn.Parameter(A.clone())
        else:
            self.register_buffer('A', A)

        if self.adaptive in ['offset', 'importance']:
            self.PA = nn.Parameter(A.clone())
            if self.adaptive == 'offset':
                # Additive term, starts close to zero.
                nn.init.uniform_(self.PA, -1e-6, 1e-6)
            elif self.adaptive == 'importance':
                # Multiplicative term, starts at one.
                nn.init.constant_(self.PA, 1)

        if self.conv_pos == 'pre':
            self.conv = nn.Conv2d(in_channels, out_channels * A.size(0), 1)
        elif self.conv_pos == 'post':
            self.conv = nn.Conv2d(A.size(0) * in_channels, out_channels, 1)

        if self.with_res:
            if in_channels != out_channels:
                self.down = Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    build_norm_layer(self.norm_cfg, out_channels)[1])
            else:
                # A module instead of a lambda keeps the layer picklable.
                self.down = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        n, c, t, v = x.shape
        res = self.down(x) if self.with_res else 0

        # Build only the adjacency variant that is actually used.
        if self.adaptive == 'offset':
            A = self.A + self.PA
        elif self.adaptive == 'importance':
            A = self.A * self.PA
        else:  # None or 'init'
            A = self.A

        if self.conv_pos == 'pre':
            x = self.conv(x)
            x = x.view(n, self.num_subsets, -1, t, v)
            x = torch.einsum('nkctv,kvw->nctw', x, A).contiguous()
        elif self.conv_pos == 'post':
            x = torch.einsum('nctv,kvw->nkctw', x, A).contiguous()
            x = x.view(n, -1, t, v)
            x = self.conv(x)

        return self.act(self.bn(x) + res)


class unit_aagcn(BaseModule):
    """The graph convolution unit of AAGCN.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): The adjacency matrix defined in the graph
            with shape of `(num_subsets, num_joints, num_joints)`.
        coff_embedding (int): The coefficient for downscaling the embedding
            dimension. Defaults to 4.
        adaptive (bool): Whether to use adaptive graph convolutional layer.
            Defaults to True.
        attention (bool): Whether to use the STC-attention module.
            Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. When ``attention`` is True, the initialization configs of
            the attention layers are appended to a copy of it.
            Defaults to ``[
                dict(type='Constant', layer='BatchNorm2d', val=1,
                     override=dict(type='Constant', name='bn', val=1e-6)),
                dict(type='Kaiming', layer='Conv2d', mode='fan_out'),
                dict(type='ConvBranch', name='conv_d')
            ]``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        A: torch.Tensor,
        coff_embedding: int = 4,
        adaptive: bool = True,
        attention: bool = True,
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(
                type='Constant',
                layer='BatchNorm2d',
                val=1,
                override=dict(type='Constant', name='bn', val=1e-6)),
            dict(type='Kaiming', layer='Conv2d', mode='fan_out'),
            dict(type='ConvBranch', name='conv_d')
        ]
    ) -> None:

        if attention:
            attention_init_cfg = [
                dict(
                    type='Constant',
                    layer='Conv1d',
                    val=0,
                    override=dict(type='Xavier', name='conv_sa')),
                dict(
                    type='Kaiming',
                    layer='Linear',
                    mode='fan_in',
                    override=dict(type='Constant', val=0, name='fc2c'))
            ]
            # Extend a private list so neither the caller's config nor the
            # shared default is mutated. ``None`` and a single dict are both
            # allowed by the annotation and must not break ``extend``.
            if init_cfg is None:
                init_cfg = []
            elif isinstance(init_cfg, dict):
                init_cfg = [init_cfg]
            else:
                init_cfg = cp.copy(init_cfg)
            init_cfg.extend(attention_init_cfg)

        super(unit_aagcn, self).__init__(init_cfg=init_cfg)
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.num_subset = A.shape[0]
        self.adaptive = adaptive
        self.attention = attention

        num_joints = A.shape[-1]

        # One 1x1 output conv per adjacency subset.
        self.conv_d = ModuleList(
            nn.Conv2d(in_channels, out_channels, 1)
            for _ in range(self.num_subset))

        if self.adaptive:
            # Clone so the learned adjacency never aliases the tensor the
            # caller passed in (same convention as ``unit_gcn``).
            self.A = nn.Parameter(A.clone())

            self.alpha = nn.Parameter(torch.zeros(1))
            self.conv_a = ModuleList()
            self.conv_b = ModuleList()
            for _ in range(self.num_subset):
                self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
                self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
        else:
            self.register_buffer('A', A)

        if self.attention:
            # temporal attention
            self.conv_ta = nn.Conv1d(out_channels, 1, 9, padding=4)
            # spatial attention: largest odd kernel not exceeding
            # ``num_joints``, padded to preserve the joint dimension
            ker_joint = num_joints if num_joints % 2 else num_joints - 1
            pad = (ker_joint - 1) // 2
            self.conv_sa = nn.Conv1d(out_channels, 1, ker_joint, padding=pad)
            # channel attention
            rr = 2
            self.fc1c = nn.Linear(out_channels, out_channels // rr)
            self.fc2c = nn.Linear(out_channels // rr, out_channels)

        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels))
        else:
            # A module instead of a lambda keeps the layer picklable.
            self.down = nn.Identity()

        self.bn = nn.BatchNorm2d(out_channels)
        self.tan = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        N, C, T, V = x.size()

        # Loop-invariant flattened view used for the graph aggregation.
        x_flat = x.view(N, C * T, V)

        y = None
        for i in range(self.num_subset):
            adj = self.A[i]
            if self.adaptive:
                emb_a = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous()
                emb_a = emb_a.view(N, V, self.inter_c * T)
                emb_b = self.conv_b[i](x).view(N, self.inter_c * T, V)
                # N V V, scaled by the embedding length ``inter_c * T``
                sim = self.tan(torch.matmul(emb_a, emb_b) / emb_a.size(-1))
                adj = adj + sim * self.alpha
            z = self.conv_d[i](torch.matmul(x_flat, adj).view(N, C, T, V))
            y = z if y is None else z + y

        y = self.relu(self.bn(y) + self.down(x))

        if self.attention:
            # spatial attention first
            se = y.mean(-2)  # N C V
            se1 = self.sigmoid(self.conv_sa(se))  # N 1 V
            y = y * se1.unsqueeze(-2) + y
            # then temporal attention
            se = y.mean(-1)  # N C T
            se1 = self.sigmoid(self.conv_ta(se))  # N 1 T
            y = y * se1.unsqueeze(-1) + y
            # then channel attention on the globally averaged features
            se = y.mean(-1).mean(-1)  # N C
            se1 = self.relu(self.fc1c(se))
            se2 = self.sigmoid(self.fc2c(se1))  # N C
            y = y * se2.unsqueeze(-1).unsqueeze(-1) + y

        return y


class unit_tcn(BaseModule):
    """The basic unit of temporal convolutional network.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the temporal convolution kernel.
            Defaults to 9.
        stride (int): Stride of the temporal convolution. Defaults to 1.
        dilation (int): Spacing between temporal kernel elements.
            Defaults to 1.
        norm (str or dict, optional): The name (or config dict) of the norm
            layer. If None, no normalization is applied.
            Defaults to ``'BN'``.
        dropout (float): Dropout probability. Defaults to 0.
        init_cfg (dict or list[dict]): Initialization config dict. Defaults to
            ``[
                dict(type='Constant', layer='BatchNorm2d', val=1),
                dict(type='Kaiming', layer='Conv2d', mode='fan_out')
            ]``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 9,
        stride: int = 1,
        dilation: int = 1,
        norm: Optional[Union[str, Dict]] = 'BN',
        dropout: float = 0,
        init_cfg: Union[Dict, List[Dict]] = [
            dict(type='Constant', layer='BatchNorm2d', val=1),
            dict(type='Kaiming', layer='Conv2d', mode='fan_out')
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm if isinstance(norm, dict) else dict(type=norm)
        # Half of the dilated kernel extent minus one, i.e.
        # ``(kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2``.
        pad = ((kernel_size - 1) * dilation) // 2

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1))
        if norm is not None:
            self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        else:
            self.bn = nn.Identity()

        self.drop = nn.Dropout(dropout, inplace=True)
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        return self.drop(self.bn(self.conv(x)))
class mstcn(BaseModule):
    """The multi-scale temporal convolutional network.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mid_channels (float, optional): Ratio of ``out_channels`` used as
            the channel number of every branch. If None, ``out_channels``
            is split evenly across the branches and the first branch takes
            the remainder. Defaults to None.
        dropout (float): Dropout probability. Defaults to 0.
        ms_cfg (list): The config of multi-scale branches. Each item is
            ``'1x1'``, ``('max', kernel_size)`` or
            ``(kernel_size, dilation)``. Defaults to
            ``[(3, 1), (3, 2), (3, 3), (3, 4), ('max', 3), '1x1']``.
        stride (int): Stride of the temporal convolution. Defaults to 1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: Optional[float] = None,
                 dropout: float = 0.,
                 ms_cfg: List = [(3, 1), (3, 2), (3, 3), (3, 4), ('max', 3),
                                 '1x1'],
                 stride: int = 1,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # Multiple branches of temporal convolution.
        # Keep a private copy so the shared default list (or the caller's
        # list) is never exposed for mutation through ``self.ms_cfg``.
        ms_cfg = list(ms_cfg)
        self.ms_cfg = ms_cfg
        num_branches = len(ms_cfg)
        self.num_branches = num_branches
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act = nn.ReLU()

        if mid_channels is None:
            mid_channels = out_channels // num_branches
            rem_mid_channels = out_channels - mid_channels * (num_branches - 1)
        else:
            # ``mid_channels`` is a ratio of ``out_channels`` here.
            assert isinstance(mid_channels, float) and mid_channels > 0
            mid_channels = int(out_channels * mid_channels)
            rem_mid_channels = mid_channels

        self.mid_channels = mid_channels
        self.rem_mid_channels = rem_mid_channels

        branches = []
        for i, cfg in enumerate(ms_cfg):
            # The first branch absorbs the channels left over by the split.
            branch_c = rem_mid_channels if i == 0 else mid_channels
            if cfg == '1x1':
                branches.append(
                    nn.Conv2d(
                        in_channels,
                        branch_c,
                        kernel_size=1,
                        stride=(stride, 1)))
                continue
            # Lists are accepted too, since configs read from YAML/JSON
            # cannot express tuples.
            assert isinstance(cfg, (tuple, list))
            if cfg[0] == 'max':
                pool_kernel = cfg[1]
                # Padding derived from the kernel size (1 for the default
                # kernel 3) so that odd pooling kernels keep the same
                # temporal length as the other branches.
                branches.append(
                    Sequential(
                        nn.Conv2d(in_channels, branch_c, kernel_size=1),
                        nn.BatchNorm2d(branch_c), self.act,
                        nn.MaxPool2d(
                            kernel_size=(pool_kernel, 1),
                            stride=(stride, 1),
                            padding=((pool_kernel - 1) // 2, 0))))
                continue
            assert isinstance(cfg[0], int) and isinstance(cfg[1], int)
            branch = Sequential(
                nn.Conv2d(in_channels, branch_c, kernel_size=1),
                nn.BatchNorm2d(branch_c), self.act,
                unit_tcn(
                    branch_c,
                    branch_c,
                    kernel_size=cfg[0],
                    stride=stride,
                    dilation=cfg[1],
                    norm=None))
            branches.append(branch)

        self.branches = ModuleList(branches)
        tin_channels = mid_channels * (num_branches - 1) + rem_mid_channels

        # Fuse the concatenated branch outputs back to ``out_channels``.
        self.transform = Sequential(
            nn.BatchNorm2d(tin_channels), self.act,
            nn.Conv2d(tin_channels, out_channels, kernel_size=1))

        self.bn = nn.BatchNorm2d(out_channels)
        self.drop = nn.Dropout(dropout, inplace=True)

    def inner_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every branch, concatenate along channels and fuse."""
        # Unpacking also rejects inputs that are not 4-dimensional.
        N, C, T, V = x.shape

        branch_outs = [tempconv(x) for tempconv in self.branches]

        feat = torch.cat(branch_outs, dim=1)
        feat = self.transform(feat)
        return feat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        out = self.inner_forward(x)
        out = self.bn(out)
        return self.drop(out)