File size: 2,279 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmaction.models import TRNHead
def test_trn_head():
"""Test loss method, layer construction, attributes and forward function in
trn head."""
from mmaction.models.heads.trn_head import (RelationModule,
RelationModuleMultiScale)
trn_head = TRNHead(num_classes=4, in_channels=2048, relation_type='TRN')
trn_head.init_weights()
assert trn_head.num_classes == 4
assert trn_head.dropout_ratio == 0.8
assert trn_head.in_channels == 2048
assert trn_head.init_std == 0.001
assert trn_head.spatial_type == 'avg'
relation_module = trn_head.consensus
assert isinstance(relation_module, RelationModule)
assert relation_module.hidden_dim == 256
assert isinstance(relation_module.classifier[3], nn.Linear)
assert relation_module.classifier[3].out_features == trn_head.num_classes
assert trn_head.dropout.p == trn_head.dropout_ratio
assert isinstance(trn_head.dropout, nn.Dropout)
assert isinstance(trn_head.fc_cls, nn.Linear)
assert trn_head.fc_cls.in_features == trn_head.in_channels
assert trn_head.fc_cls.out_features == trn_head.hidden_dim
assert isinstance(trn_head.avg_pool, nn.AdaptiveAvgPool2d)
assert trn_head.avg_pool.output_size == 1
input_shape = (8, 2048, 7, 7)
feat = torch.rand(input_shape)
# tsm head inference with no init
num_segs = input_shape[0]
cls_scores = trn_head(feat, num_segs)
assert cls_scores.shape == torch.Size([1, 4])
# tsm head inference with init
trn_head = TRNHead(
num_classes=4,
in_channels=2048,
num_segments=8,
relation_type='TRNMultiScale')
trn_head.init_weights()
assert isinstance(trn_head.consensus, RelationModuleMultiScale)
assert trn_head.consensus.scales == range(8, 1, -1)
cls_scores = trn_head(feat, num_segs)
assert cls_scores.shape == torch.Size([1, 4])
with pytest.raises(ValueError):
trn_head = TRNHead(
num_classes=4,
in_channels=2048,
num_segments=8,
relation_type='RelationModlue')
|