|
|
|
|
|
import numpy as np
|
|
|
import pytest
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from mmcv.transforms import to_tensor
|
|
|
|
|
|
from mmaction.models import CutmixBlending, MixupBlending, RandomBatchAugment
|
|
|
from mmaction.structures import ActionDataSample
|
|
|
|
|
|
|
|
|
def get_label(label_):
    """Wrap every entry of ``label_`` in its own ``ActionDataSample``.

    Args:
        label_: An iterable of per-clip labels. The tests in this module
            pass either a list of label tensors or a 2-D one-hot tensor,
            whose rows are iterated.

    Returns:
        list[ActionDataSample]: One new data sample per entry, with its
        ground-truth label set via ``set_gt_label``.
    """
    data_samples = []
    # Iterate the entries directly; the previous version enumerated them
    # and then re-indexed ``label_[idx]``, leaving the loop variable unused.
    for one_label in label_:
        data_sample = ActionDataSample()
        data_sample.set_gt_label(one_label)
        data_samples.append(data_sample)
    return data_samples
|
|
|
|
|
|
|
|
|
def test_mixup():
    """``MixupBlending`` must keep the batch shape and the sample count.

    Three batches of four clips are blended: 5-D input with integer class
    labels, 6-D input with integer class labels, and 6-D input with one-hot
    labels. A fresh list of data samples is built for every case.
    """
    num_classes = 10
    mixup = MixupBlending(num_classes, 0.2)

    # Each case: (input tensor shape, whether labels are one-hot encoded).
    cases = (
        ((4, 4, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), True),
    )
    for shape, use_one_hot in cases:
        if use_one_hot:
            raw_labels = F.one_hot(torch.arange(4), num_classes=num_classes)
        else:
            raw_labels = [to_tensor(i) for i in range(4)]
        samples = get_label(raw_labels)
        clips = torch.randn(*shape)

        blended_clips, blended_samples = mixup(clips, samples)

        assert blended_clips.shape == torch.Size(shape)
        assert len(blended_samples) == 4
|
|
|
|
|
|
|
|
|
def test_cutmix():
    """``CutmixBlending`` must keep the batch shape and the sample count.

    Three batches of four clips are blended: 5-D input with integer class
    labels, 6-D input with integer class labels, and 6-D input with one-hot
    labels. A fresh list of data samples is built for every case.
    """
    num_classes = 10
    cutmix = CutmixBlending(num_classes, 0.2)

    # Each case: (input tensor shape, whether labels are one-hot encoded).
    cases = (
        ((4, 4, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), True),
    )
    for shape, use_one_hot in cases:
        if use_one_hot:
            raw_labels = F.one_hot(torch.arange(4), num_classes=num_classes)
        else:
            raw_labels = [to_tensor(i) for i in range(4)]
        samples = get_label(raw_labels)
        clips = torch.randn(*shape)

        blended_clips, blended_samples = cutmix(clips, samples)

        assert blended_clips.shape == torch.Size(shape)
        assert len(blended_samples) == 4
|
|
|
|
|
|
|
|
|
def test_rand_blend():
    """Exercise ``RandomBatchAugment`` construction and blending output.

    Construction checks:
    * ``probs=[0.5, 0.6]`` (sum 1.1) is expected to raise ``AssertionError``;
    * ``probs=None`` is stored unchanged;
    * with ``probs=[0.5, 0.4]`` the last stored probability is expected to
      be the 0.1 remainder.

    The last augmenter is then applied to three batches of four clips:
    5-D and 6-D inputs with integer labels, and 6-D input with one-hot
    labels. Each output must keep its input shape and sample count.
    """
    num_classes = 10
    blending_augs = [
        dict(type='MixupBlending', alpha=0.2, num_classes=num_classes),
        dict(type='CutmixBlending', alpha=0.2, num_classes=num_classes),
    ]

    # Probabilities summing above 1 are rejected.
    with pytest.raises(AssertionError):
        RandomBatchAugment(blending_augs, [0.5, 0.6])

    # ``None`` is kept as-is.
    assert RandomBatchAugment(blending_augs, probs=None).probs is None

    # A sum below 1 leaves a trailing 0.1 entry.
    rand_mix = RandomBatchAugment(blending_augs, [0.5, 0.4])
    np.testing.assert_allclose(rand_mix.probs[-1], 0.1)

    # Each case: (input tensor shape, whether labels are one-hot encoded).
    cases = (
        ((4, 4, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), False),
        ((4, 4, 2, 3, 32, 32), True),
    )
    for shape, use_one_hot in cases:
        if use_one_hot:
            raw_labels = F.one_hot(torch.arange(4), num_classes=num_classes)
        else:
            raw_labels = [to_tensor(i) for i in range(4)]
        samples = get_label(raw_labels)
        clips = torch.randn(*shape)

        blended_clips, blended_samples = rand_mix(clips, samples)

        assert blended_clips.shape == torch.Size(shape)
        assert len(blended_samples) == 4
|
|
|
|