File size: 4,058 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from mmcv.transforms import to_tensor

from mmaction.models import CutmixBlending, MixupBlending, RandomBatchAugment
from mmaction.structures import ActionDataSample


def get_label(label_):
    """Wrap each label in its own ``ActionDataSample``.

    Args:
        label_: Iterable of labels. Every item is handed unchanged to
            ``ActionDataSample.set_gt_label``.

    Returns:
        list: One newly created ``ActionDataSample`` per item of
        ``label_``, in iteration order.
    """
    label = []
    # Consume the items directly instead of re-indexing ``label_`` by
    # position; this drops the unused loop variable and works for any
    # iterable, not only indexable sequences.
    for one_label in label_:
        data_sample = ActionDataSample()
        data_sample.set_gt_label(one_label)
        label.append(data_sample)
    return label


def test_mixup():
    """Run ``MixupBlending`` on several inputs and check the output sizes.

    Every case builds a fresh list of four data samples, blends a random
    image tensor with them, and asserts that the returned images keep the
    input shape and that four labels come back.
    """
    num_classes, alpha = 10, 0.2
    batch_size = 4
    blender = MixupBlending(num_classes, alpha)

    def index_labels():
        # one tensor label per sample, built from the integers 0..3
        return get_label([to_tensor(i) for i in range(batch_size)])

    def one_hot_labels():
        # multi-label with one-hot tensor as label
        return get_label(
            F.one_hot(torch.arange(batch_size), num_classes=num_classes))

    cases = (
        ((4, 4, 3, 32, 32), index_labels),  # 5-D imgs
        ((4, 4, 2, 3, 32, 32), index_labels),  # 6-D imgs
        ((4, 4, 2, 3, 32, 32), one_hot_labels),  # 6-D imgs, one-hot
    )
    for shape, make_labels in cases:
        # Labels are rebuilt for every case rather than shared between
        # calls, as the original test did before each 6-D call.
        samples = make_labels()
        batch = torch.randn(*shape)
        blended, blended_labels = blender(batch, samples)
        assert blended.shape == torch.Size(shape)
        assert len(blended_labels) == batch_size


def test_cutmix():
    """Run ``CutmixBlending`` on several inputs and check the output sizes.

    Every case builds a fresh list of four data samples, blends a random
    image tensor with them, and asserts that the returned images keep the
    input shape and that four labels come back.
    """
    num_classes, alpha = 10, 0.2
    batch_size = 4
    blender = CutmixBlending(num_classes, alpha)

    def index_labels():
        # one tensor label per sample, built from the integers 0..3
        return get_label([to_tensor(i) for i in range(batch_size)])

    def one_hot_labels():
        # multi-label with one-hot tensor as label
        return get_label(
            F.one_hot(torch.arange(batch_size), num_classes=num_classes))

    cases = (
        ((4, 4, 3, 32, 32), index_labels),  # 5-D imgs
        ((4, 4, 2, 3, 32, 32), index_labels),  # 6-D imgs
        ((4, 4, 2, 3, 32, 32), one_hot_labels),  # 6-D imgs, one-hot
    )
    for shape, make_labels in cases:
        # Labels are rebuilt for every case rather than shared between
        # calls, as the original test did before each 6-D call.
        samples = make_labels()
        batch = torch.randn(*shape)
        blended, blended_labels = blender(batch, samples)
        assert blended.shape == torch.Size(shape)
        assert len(blended_labels) == batch_size


def test_rand_blend():
    """Check ``RandomBatchAugment`` construction and its call output sizes.

    The constructor is exercised with three ``probs`` settings, then the
    last instance is called on several inputs; each call must keep the
    image shape and return four labels.
    """
    num_classes = 10
    batch_size = 4
    alpha_mixup = 0.2
    alpha_cutmix = 0.2
    blending_augs = [
        dict(type='MixupBlending', alpha=alpha_mixup, num_classes=num_classes),
        dict(
            type='CutmixBlending', alpha=alpha_cutmix, num_classes=num_classes)
    ]

    # probs [0.5, 0.6] must be rejected with an AssertionError
    with pytest.raises(AssertionError):
        RandomBatchAugment(blending_augs, [0.5, 0.6])

    # mixup, cutmix: without probs the attribute stays None
    without_probs = RandomBatchAugment(blending_augs, probs=None)
    assert without_probs.probs is None

    # mixup, cutmix and None: the last stored probability is close to 0.1
    rand_mix = RandomBatchAugment(blending_augs, [0.5, 0.4])
    np.testing.assert_allclose(rand_mix.probs[-1], 0.1)

    def index_labels():
        # one tensor label per sample, built from the integers 0..3
        return get_label([to_tensor(i) for i in range(batch_size)])

    def one_hot_labels():
        # multi-label with one-hot tensor as label
        return get_label(
            F.one_hot(torch.arange(batch_size), num_classes=num_classes))

    # test call
    cases = (
        ((4, 4, 3, 32, 32), index_labels),  # 5-D imgs
        ((4, 4, 2, 3, 32, 32), index_labels),  # 6-D imgs
        ((4, 4, 2, 3, 32, 32), one_hot_labels),  # 6-D imgs, one-hot
    )
    for shape, make_labels in cases:
        samples = make_labels()
        batch = torch.randn(*shape)
        blended, blended_labels = rand_mix(batch, samples)
        assert blended.shape == torch.Size(shape)
        assert len(blended_labels) == batch_size