|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
|
|
|
from mmaction.models import SimpleMeanAdapter, TransformerAdapter
|
|
|
|
|
|
|
|
|
def test_transformer_adapter():
|
|
|
"""Test transformer adapter."""
|
|
|
with pytest.raises(RuntimeError):
|
|
|
num_segs_model = 8
|
|
|
num_segs_features = 9
|
|
|
adapter = TransformerAdapter(
|
|
|
num_segs=num_segs_model,
|
|
|
transformer_width=64,
|
|
|
transformer_heads=8,
|
|
|
transformer_layers=2)
|
|
|
features = torch.randn(2, num_segs_features, 64)
|
|
|
adapter(features)
|
|
|
|
|
|
num_segs = 8
|
|
|
adapter = TransformerAdapter(
|
|
|
num_segs=num_segs,
|
|
|
transformer_width=64,
|
|
|
transformer_heads=8,
|
|
|
transformer_layers=2)
|
|
|
adapter.init_weights()
|
|
|
features = torch.randn(2, num_segs, 64)
|
|
|
adapted_features = adapter(features)
|
|
|
assert adapted_features.shape == torch.Size([2, 64])
|
|
|
|
|
|
|
|
|
def test_simple_mean_adapter():
|
|
|
"""Test simple mean adapter."""
|
|
|
|
|
|
adapter = SimpleMeanAdapter(dim=1)
|
|
|
features = torch.randn(2, 8, 64)
|
|
|
adapted_features = adapter(features)
|
|
|
assert adapted_features.shape == torch.Size([2, 64])
|
|
|
|
|
|
adapter = SimpleMeanAdapter(dim=(1, 2))
|
|
|
features = torch.randn(2, 8, 2, 64)
|
|
|
adapted_features = adapter(features)
|
|
|
assert adapted_features.shape == torch.Size([2, 64])
|
|
|
|