File size: 1,375 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmaction.models import SimpleMeanAdapter, TransformerAdapter
def test_transformer_adapter():
"""Test transformer adapter."""
with pytest.raises(RuntimeError):
num_segs_model = 8
num_segs_features = 9
adapter = TransformerAdapter(
num_segs=num_segs_model,
transformer_width=64,
transformer_heads=8,
transformer_layers=2)
features = torch.randn(2, num_segs_features, 64)
adapter(features)
num_segs = 8
adapter = TransformerAdapter(
num_segs=num_segs,
transformer_width=64,
transformer_heads=8,
transformer_layers=2)
adapter.init_weights()
features = torch.randn(2, num_segs, 64)
adapted_features = adapter(features)
assert adapted_features.shape == torch.Size([2, 64])
def test_simple_mean_adapter():
"""Test simple mean adapter."""
adapter = SimpleMeanAdapter(dim=1)
features = torch.randn(2, 8, 64)
adapted_features = adapter(features)
assert adapted_features.shape == torch.Size([2, 64])
adapter = SimpleMeanAdapter(dim=(1, 2))
features = torch.randn(2, 8, 2, 64)
adapted_features = adapter(features)
assert adapted_features.shape == torch.Size([2, 64])
|