File size: 4,058 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from mmcv.transforms import to_tensor
from mmaction.models import CutmixBlending, MixupBlending, RandomBatchAugment
from mmaction.structures import ActionDataSample
def get_label(label_):
    """Wrap each label in its own ``ActionDataSample``.

    Args:
        label_: An iterable of labels. Each element is passed unchanged to
            ``ActionDataSample.set_gt_label``. In this file it is either a
            list of tensors built with ``to_tensor`` or the rows of a
            one-hot tensor.

    Returns:
        list: One ``ActionDataSample`` per label, in input order.
    """
    data_samples = []
    # Iterate the elements directly; iterating a tensor yields its rows, the
    # same values that positional indexing would return.
    for one_label in label_:
        data_sample = ActionDataSample()
        data_sample.set_gt_label(one_label)
        data_samples.append(data_sample)
    return data_samples
def test_mixup():
    """Check that ``MixupBlending`` blends 5-D and 6-D batches sensibly.

    Three scenarios are run, each on a batch of 4 samples:
    1. integer class-index labels on a 5-D batch;
    2. the same kind of labels on a 6-D batch;
    3. one-hot label rows on a 6-D batch.

    Each scenario checks that:
    - the blended images keep their shape;
    - one data sample comes back per batch element;
    - every blended ``gt_label`` holds ``num_classes`` values summing to one.

    The last check assumes the blender returns a convex mix of one-hot label
    vectors. Revisit it if that contract changes.
    """
    alpha = 0.2
    num_classes = 10
    batch_size = 4
    mixup = MixupBlending(num_classes, alpha)

    def check_blending(imgs, label):
        # Capture the shape before the call in case the blender modifies
        # ``imgs`` in place.
        expected_shape = imgs.shape
        mixed_imgs, mixed_label = mixup(imgs, label)
        assert mixed_imgs.shape == expected_shape
        assert len(mixed_label) == batch_size
        for data_sample in mixed_label:
            gt_label = data_sample.gt_label
            assert gt_label.numel() == num_classes
            assert abs(gt_label.sum().item() - 1.0) < 1e-5

    def index_labels():
        # Fresh data samples per scenario, so no scenario sees labels left
        # over from a previous call.
        return get_label([to_tensor(x) for x in range(batch_size)])

    # NCHW imgs (5-D batch)
    check_blending(torch.randn(batch_size, 4, 3, 32, 32), index_labels())
    # NCTHW imgs (6-D batch)
    check_blending(torch.randn(batch_size, 4, 2, 3, 32, 32), index_labels())
    # multi-label with one-hot tensor as label
    check_blending(
        torch.randn(batch_size, 4, 2, 3, 32, 32),
        get_label(
            F.one_hot(torch.arange(batch_size), num_classes=num_classes)))
def test_cutmix():
    """Check that ``CutmixBlending`` blends 5-D and 6-D batches sensibly.

    Three scenarios are run, each on a batch of 4 samples:
    1. integer class-index labels on a 5-D batch;
    2. the same kind of labels on a 6-D batch;
    3. one-hot label rows on a 6-D batch.

    Each scenario checks that:
    - the blended images keep their shape;
    - one data sample comes back per batch element;
    - every blended ``gt_label`` holds ``num_classes`` values summing to one.

    The last check assumes the blender returns a convex mix of one-hot label
    vectors. Revisit it if that contract changes.
    """
    alpha = 0.2
    num_classes = 10
    batch_size = 4
    cutmix = CutmixBlending(num_classes, alpha)

    def check_blending(imgs, label):
        # Capture the shape before the call in case the blender modifies
        # ``imgs`` in place.
        expected_shape = imgs.shape
        mixed_imgs, mixed_label = cutmix(imgs, label)
        assert mixed_imgs.shape == expected_shape
        assert len(mixed_label) == batch_size
        for data_sample in mixed_label:
            gt_label = data_sample.gt_label
            assert gt_label.numel() == num_classes
            assert abs(gt_label.sum().item() - 1.0) < 1e-5

    def index_labels():
        # Fresh data samples per scenario, so no scenario sees labels left
        # over from a previous call.
        return get_label([to_tensor(x) for x in range(batch_size)])

    # NCHW imgs (5-D batch)
    check_blending(torch.randn(batch_size, 4, 3, 32, 32), index_labels())
    # NCTHW imgs (6-D batch)
    check_blending(torch.randn(batch_size, 4, 2, 3, 32, 32), index_labels())
    # multi-label with one-hot tensor as label
    check_blending(
        torch.randn(batch_size, 4, 2, 3, 32, 32),
        get_label(
            F.one_hot(torch.arange(batch_size), num_classes=num_classes)))
def test_rand_blend():
    """Exercise ``RandomBatchAugment`` built from Mixup and Cutmix configs.

    Construction checks:
    - probabilities ``[0.5, 0.6]`` (total above one) raise ``AssertionError``;
    - ``probs=None`` is stored unchanged;
    - explicit probabilities ``[0.5, 0.4]`` end with a trailing entry of 0.1.

    The augmenter is then applied to 5-D and 6-D batches of 4 samples with
    index and one-hot labels. Each call must preserve the image shape and
    return 4 data samples.
    """
    num_classes = 10
    index_samples = get_label([to_tensor(x) for x in range(4)])
    blending_augs = [
        dict(type='MixupBlending', alpha=0.2, num_classes=num_classes),
        dict(type='CutmixBlending', alpha=0.2, num_classes=num_classes),
    ]

    # Probabilities that add up to more than one must be rejected.
    with pytest.raises(AssertionError):
        RandomBatchAugment(blending_augs, [0.5, 0.6])

    # mixup, cutmix: no probabilities given, so none are stored.
    assert RandomBatchAugment(blending_augs, probs=None).probs is None

    # mixup, cutmix and None: the leftover 0.1 appears as the last entry.
    probs = [0.5, 0.4]
    rand_mix = RandomBatchAugment(blending_augs, probs)
    np.testing.assert_allclose(rand_mix.probs[-1], 0.1)

    # (batch shape, data samples) pairs: NCHW imgs, NCTHW imgs, and
    # multi-label one-hot rows on NCTHW imgs.
    scenarios = [
        ((4, 4, 3, 32, 32), index_samples),
        ((4, 4, 2, 3, 32, 32), get_label([to_tensor(x) for x in range(4)])),
        ((4, 4, 2, 3, 32, 32),
         get_label(F.one_hot(torch.arange(4), num_classes=num_classes))),
    ]
    for batch_shape, data_samples in scenarios:
        blended_imgs, blended_samples = rand_mix(
            torch.randn(*batch_shape), data_samples)
        assert blended_imgs.shape == torch.Size(batch_shape)
        assert len(blended_samples) == 4
|