|
|
|
|
|
import torch
|
|
|
|
|
|
from mmaction.models import STGCN
|
|
|
|
|
|
|
|
|
def test_stgcn_backbone():
    """Check the output shape of the STGCN backbone for several graph setups.

    For each of the 'openpose', 'nturgb+d' and 'coco' layouts a model is
    built in 'stgcn_spatial' mode, its weights are initialised, and a random
    tensor of shape (batch_size, num_person, num_frames, num_joints, 3) is
    forwarded through it. A final model using the 'coco' layout with
    mode='spatial' and non-default ``gcn_adaptive`` / ``gcn_with_res`` /
    ``tcn_type`` options is then run on the same 'coco' input tensor.

    Every output must have the shape
    (batch_size, num_person, 256, 38, num_joints).
    """
    mode = 'stgcn_spatial'
    batch_size, num_person, num_frames = 2, 2, 150

    # Middle two output dimensions expected for every configuration below.
    # NOTE(review): presumably output channels and the reduced frame count
    # for 150 input frames -- confirm against the STGCN implementation.
    mid_dims = (256, 38)

    # (layout name, number of joints in the input for that layout)
    layout_joints = (('openpose', 18), ('nturgb+d', 25), ('coco', 17))

    for layout, num_joints in layout_joints:
        model = STGCN(graph_cfg=dict(layout=layout, mode=mode))
        model.init_weights()
        # Input is created after the model so random-number consumption
        # follows the same order in every iteration.
        inputs = torch.randn(
            batch_size, num_person, num_frames, num_joints, 3)
        output = model(inputs)
        expected = torch.Size(
            [batch_size, num_person, *mid_dims, num_joints])
        assert output.shape == expected

    # `inputs` and `num_joints` still hold the values of the last loop entry
    # ('coco', 17); the model with custom settings is fed that same tensor.
    custom_cfg = dict(
        graph_cfg=dict(layout='coco', mode='spatial'),
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn')
    model = STGCN(**custom_cfg)
    model.init_weights()
    output = model(inputs)
    expected = torch.Size([batch_size, num_person, *mid_dims, num_joints])
    assert output.shape == expected
|
|
|
|