AZIIIIIIIIZ's picture
Upload 1039 files
d670799 verified
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmaction.models import TSMHead
def test_tsm_head():
    """Check TSMHead construction, exposed attributes and forward shapes.

    A head is built with ``num_classes=4`` and ``in_channels=2048``; its
    scalar attributes and sub-modules are verified, and a forward pass is run
    on a random ``(8, 2048, 7, 7)`` tensor. The forward pass is then repeated
    with a second head built with ``temporal_pool=True``.
    """
    head = TSMHead(num_classes=4, in_channels=2048)
    head.init_weights()

    # Scalar attributes: the two constructor arguments plus values that were
    # not passed explicitly when building the head.
    expected_attrs = {
        'num_classes': 4,
        'dropout_ratio': 0.8,
        'in_channels': 2048,
        'init_std': 0.001,
    }
    for attr_name, expected in expected_attrs.items():
        assert getattr(head, attr_name) == expected
    assert head.consensus.dim == 1
    assert head.spatial_type == 'avg'

    # Sub-modules must have the expected types and be wired to the head's
    # own configuration values.
    dropout = head.dropout
    assert isinstance(dropout, nn.Dropout)
    assert dropout.p == head.dropout_ratio

    classifier = head.fc_cls
    assert isinstance(classifier, nn.Linear)
    assert classifier.in_features == head.in_channels
    assert classifier.out_features == head.num_classes

    pool = head.avg_pool
    assert isinstance(pool, nn.AdaptiveAvgPool2d)
    assert pool.output_size == 1

    # Random input; the leading dimension doubles as the segment count.
    feat_shape = (8, 2048, 7, 7)
    feat = torch.rand(feat_shape)
    num_segs = feat_shape[0]

    # Forward without temporal pooling: all 8 rows collapse into one row of
    # 4 class scores.
    scores = head(feat, num_segs)
    assert scores.shape == torch.Size((1, 4))

    # Forward with temporal pooling enabled: the test expects two rows of
    # 4 class scores for the same input.
    # NOTE(review): presumably temporal pooling halves the segment count
    # used for grouping -- confirm against the TSMHead implementation.
    pooled_head = TSMHead(
        num_classes=4, in_channels=2048, temporal_pool=True)
    pooled_head.init_weights()
    pooled_scores = pooled_head(feat, num_segs)
    assert pooled_scores.shape == torch.Size((2, 4))