AZIIIIIIIIZ's picture
Upload 1039 files
d670799 verified
raw
history blame
1.47 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models import STGCN
def test_stgcn_backbone():
    """Check the output shape of the STGCN backbone.

    A random tensor of shape
    ``(batch_size, num_person, num_frames, num_joints, 3)`` is fed through
    the backbone for each skeleton layout in turn (openpose, nturgb+d,
    coco), and the output is asserted to have shape
    ``(batch_size, num_person, 256, 38, num_joints)``. The same check is
    then repeated for a customised STGCN++ configuration on the coco
    layout.
    """
    graph_mode = 'stgcn_spatial'
    batch_size, num_person, num_frames = 2, 2, 150

    def expected_shape(joints):
        # Only the trailing joint dimension differs between layouts.
        return torch.Size([batch_size, num_person, 256, 38, joints])

    # (layout name, joint count). 'coco' must stay last: the STGCN++ check
    # below reuses the input tensor built for it.
    layout_cases = (
        ('openpose', 18),  # openpose-18 layout
        ('nturgb+d', 25),  # nturgb+d layout
        ('coco', 17),  # coco layout
    )
    for layout, num_joints in layout_cases:
        model = STGCN(graph_cfg=dict(layout=layout, mode=graph_mode))
        model.init_weights()
        inputs = torch.randn(batch_size, num_person, num_frames, num_joints,
                             3)
        output = model(inputs)
        assert output.shape == expected_shape(num_joints)

    # custom settings
    # instantiate STGCN++
    stgcn_pp_cfg = dict(
        graph_cfg=dict(layout='coco', mode='spatial'),
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn')
    model = STGCN(**stgcn_pp_cfg)
    model.init_weights()
    # `inputs` is still the 17-joint coco tensor from the last loop pass.
    output = model(inputs)
    assert output.shape == expected_shape(17)