| |
|
| | import torch
|
| |
|
| | from mmaction.models import STGCN
|
| |
|
| |
|
def test_stgcn_backbone():
    """Test STGCN backbone output shapes across skeleton layouts and options.

    For each graph layout (openpose, nturgb+d, coco) the model is built with
    ``mode='stgcn_spatial'``, initialised, fed a random tensor of shape
    ``(batch_size, num_person, num_frames, num_joints, 3)`` and its output
    shape is checked. A final case builds a coco-layout model with
    ``mode='spatial'``, ``gcn_adaptive='init'``, ``gcn_with_res=True`` and
    ``tcn_type='mstcn'`` and checks the same output shape.
    """
    batch_size, num_person, num_frames = 2, 2, 150
    # Expected sizes of output dims 2 and 3 for num_frames=150 inputs
    # (presumably feature channels and downsampled time -- pinned by this
    # test, not derived from the model config here).
    out_channels, out_frames = 256, 38

    def _check(model, num_joints):
        # Initialise, run on fresh random input and verify the output shape;
        # each case gets its own inputs so cases do not depend on each other.
        model.init_weights()
        inputs = torch.randn(batch_size, num_person, num_frames, num_joints,
                             3)
        output = model(inputs)
        assert output.shape == torch.Size(
            [batch_size, num_person, out_channels, out_frames, num_joints])

    # (graph layout, number of joints fed for that layout)
    layouts = [('openpose', 18), ('nturgb+d', 25), ('coco', 17)]
    for layout, num_joints in layouts:
        model = STGCN(graph_cfg=dict(layout=layout, mode='stgcn_spatial'))
        _check(model, num_joints)

    # Coco layout with the extra GCN/TCN options exercised by this test.
    model = STGCN(
        graph_cfg=dict(layout='coco', mode='spatial'),
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn')
    _check(model, num_joints=17)
|
| |
|