| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| |
|
| | import numpy as np |
| | import pytest |
| | import torch |
| | from PIL import Image |
| | from tests.test_utils import ( |
| | assert_expected, |
| | get_asset_path, |
| | set_rng_seed, |
| | skip_if_no_ffmpeg, |
| | ) |
| | from torchmultimodal.transforms.mae_transform import ( |
| | AudioEvalTransform, |
| | AudioFineTuneTransform, |
| | AudioPretrainTransform, |
| | ImageEvalTransform, |
| | ImagePretrainTransform, |
| | MixUpCutMix, |
| | RandAug, |
| | ) |
| | from torchvision import transforms |
| |
|
# Repository-relative locations of the test assets exercised by this module.
# NOTE(review): the tests in this file open these assets through
# ``get_asset_path(<file name>)`` rather than through these constants.
_ASSETS_ROOT = "tests/assets"
IMAGE_PATH = f"{_ASSETS_ROOT}/test_image.jpg"
WAV_PATH = f"{_ASSETS_ROOT}/kaldi_file_8000.wav"
MIXUP_WAV_PATH = f"{_ASSETS_ROOT}/sinewave.wav"
| |
|
| |
|
@pytest.fixture
def image():
    """Yield the ``test_image.jpg`` test asset as a PIL image.

    ``Image.open`` is lazy: it identifies the file but keeps it open until the
    pixel data is used or the image is closed. Opening it in a ``with`` block
    and yielding closes the handle once the requesting test finishes. This
    avoids leaking one file handle (and a ``ResourceWarning``) per test.
    """
    with Image.open(get_asset_path("test_image.jpg")) as img:
        yield img
| |
|
| |
|
class TestImageEvalTransform:
    """Shape and mean checks for ImageEvalTransform on the test image."""

    # Expected spatial size of the output and mean of its values.
    _SIZE = 224
    _MEAN = -0.4967

    @pytest.fixture
    def transform(self):
        return ImageEvalTransform(input_size=self._SIZE)

    def test_transform(self, transform, image):
        out = transform(image)
        assert_expected(out.size(), (3, self._SIZE, self._SIZE))
        assert_expected(out.mean().item(), self._MEAN, atol=0.0001, rtol=0.0)

    def test_transform_list(self, transform, image):
        # A list input comes back with an extra leading batch dimension.
        batch = transform([image])
        assert_expected(batch.size(), (1, 3, self._SIZE, self._SIZE))
        assert_expected(batch.mean().item(), self._MEAN, atol=0.0001, rtol=0.0)
| |
|
| |
|
class TestImagePretrainTransform:
    """Seeded shape and mean checks for ImagePretrainTransform."""

    _SIZE = 224
    # Expected mean assumes the RNG seed applied by the ``set_seed`` fixture.
    _MEAN = -0.4625

    @pytest.fixture
    def transform(self):
        return ImagePretrainTransform(input_size=self._SIZE)

    @pytest.fixture(autouse=True)
    def set_seed(self):
        set_rng_seed(0)

    def test_transform(self, transform, image):
        out = transform(image)
        assert_expected(out.size(), (3, self._SIZE, self._SIZE))
        assert_expected(out.mean().item(), self._MEAN, atol=0.0001, rtol=0.0)

    def test_transform_list(self, transform, image):
        # A list input comes back with an extra leading batch dimension.
        batch = transform([image])
        assert_expected(batch.size(), (1, 3, self._SIZE, self._SIZE))
        assert_expected(batch.mean().item(), self._MEAN, atol=0.0001, rtol=0.0)
| |
|
| |
|
class TestMixup:
    """Tests for MixUpCutMix covering the mixup, cutmix and no-augment paths.

    All expected tensors are pinned to the RNG state established by the
    autouse ``set_seed`` fixture.
    """

    @pytest.fixture(autouse=True)
    def set_seed(self):
        # Seed both torch and numpy so the augmentation's random draws are
        # reproducible.
        torch.manual_seed(0)
        np.random.seed(0)

    @pytest.fixture
    def inputs(self):
        # A batch of four single-channel 4x4 images: shape (4, 1, 4, 4).
        return torch.tensor(
            [
                [
                    [
                        [1.0, 2.0, 1.0, 2.0],
                        [2.0, 1.0, 1.0, 2.0],
                        [2.0, 2.0, 1.0, 0.0],
                        [7.0, 5.0, 4.0, 9.0],
                    ]
                ],
                [
                    [
                        [2.0, 5.0, 1.0, 2.0],
                        [1.0, 5.0, 1.0, 2.0],
                        [3.0, 1.0, 9.0, 2.0],
                        [5.0, 2.0, 1.0, 2.0],
                    ]
                ],
                [
                    [
                        [3.0, 4.0, 1.0, 2.0],
                        [6.0, 3.0, 1.0, 2.0],
                        [6.0, 5.0, 4.0, 1.0],
                        [7.0, 7.0, 9.0, 7.0],
                    ]
                ],
                [
                    [
                        [1.0, 1.0, 1.0, 2.0],
                        [1.0, 1.0, 1.0, 2.0],
                        [8.0, 0.0, 4.0, 6.0],
                        [0.0, 0.0, 3.0, 2.0],
                    ]
                ],
            ]
        )

    @pytest.fixture
    def targets(self):
        # Integer class labels, one per image in ``inputs``.
        return torch.tensor([1, 2, 3, 0], dtype=torch.long)

    def test_mixup(self, inputs, targets):
        # switch_prob=0 with cutmix_alpha=0 exercises the mixup branch: each
        # expected image is a convex blend with the image at the mirrored
        # batch position.
        mixup = MixUpCutMix(
            augment_prob=1, mixup_alpha=1, switch_prob=0, cutmix_alpha=0, classes=4
        )
        actual_images, actual_targets = mixup(images=inputs, targets=targets)
        assert_expected(
            actual_images,
            torch.tensor(
                [
                    [
                        [
                            [1.0000, 1.5626, 1.0000, 2.0000],
                            [1.5626, 1.0000, 1.0000, 2.0000],
                            [4.6245, 1.1252, 2.3123, 2.6245],
                            [3.9381, 2.8129, 3.5626, 5.9381],
                        ]
                    ],
                    [
                        [
                            [2.4374, 4.5626, 1.0000, 2.0000],
                            [3.1871, 4.1252, 1.0000, 2.0000],
                            [4.3123, 2.7497, 6.8129, 1.5626],
                            [5.8748, 4.1871, 4.4993, 4.1871],
                        ]
                    ],
                    [
                        [
                            [2.5626, 4.4374, 1.0000, 2.0000],
                            [3.8129, 3.8748, 1.0000, 2.0000],
                            [4.6877, 3.2503, 6.1871, 1.4374],
                            [6.1252, 4.8129, 5.5007, 4.8129],
                        ]
                    ],
                    [
                        [
                            [1.0000, 1.4374, 1.0000, 2.0000],
                            [1.4374, 1.0000, 1.0000, 2.0000],
                            [5.3755, 0.8748, 2.6877, 3.3755],
                            [3.0619, 2.1871, 3.4374, 5.0619],
                        ]
                    ],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )
        assert_expected(
            actual_targets,
            torch.tensor(
                [
                    [0.4187, 0.5313, 0.0250, 0.0250],
                    [0.0250, 0.0250, 0.5313, 0.4187],
                    [0.0250, 0.0250, 0.4187, 0.5313],
                    [0.5313, 0.4187, 0.0250, 0.0250],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )

    def test_cutmix(self, inputs, targets):
        # switch_prob=1 exercises the cutmix branch: the expected images
        # differ from the inputs only in a patch copied from the image at the
        # mirrored batch position.
        cutmix = MixUpCutMix(
            augment_prob=1, mixup_alpha=1, switch_prob=1, cutmix_alpha=1, classes=4
        )
        actual_images, actual_targets = cutmix(images=inputs, targets=targets)
        assert_expected(
            actual_images,
            torch.tensor(
                [
                    [
                        [
                            [1.0, 2.0, 1.0, 2.0],
                            [2.0, 1.0, 1.0, 2.0],
                            [8.0, 2.0, 1.0, 0.0],
                            [0.0, 5.0, 4.0, 9.0],
                        ]
                    ],
                    [
                        [
                            [2.0, 5.0, 1.0, 2.0],
                            [1.0, 5.0, 1.0, 2.0],
                            [6.0, 1.0, 9.0, 2.0],
                            [7.0, 2.0, 1.0, 2.0],
                        ]
                    ],
                    [
                        [
                            [3.0, 4.0, 1.0, 2.0],
                            [6.0, 3.0, 1.0, 2.0],
                            [3.0, 5.0, 4.0, 1.0],
                            [5.0, 7.0, 9.0, 7.0],
                        ]
                    ],
                    [
                        [
                            [1.0, 1.0, 1.0, 2.0],
                            [1.0, 1.0, 1.0, 2.0],
                            [2.0, 0.0, 4.0, 6.0],
                            [7.0, 0.0, 3.0, 2.0],
                        ]
                    ],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )
        assert_expected(
            actual_targets,
            torch.tensor(
                [
                    [0.1375, 0.8125, 0.0250, 0.0250],
                    [0.0250, 0.0250, 0.8125, 0.1375],
                    [0.0250, 0.0250, 0.1375, 0.8125],
                    [0.8125, 0.1375, 0.0250, 0.0250],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )

    def test_no_augment(self, inputs, targets):
        no_augment = MixUpCutMix(
            augment_prob=0, mixup_alpha=1, switch_prob=1, cutmix_alpha=1, classes=4
        )
        # Snapshot the batch *before* the call. If the transform returns (or
        # mutates in place) the very tensor it was given, comparing against
        # ``inputs`` afterwards would compare a tensor with itself and always
        # pass.
        expected_images = inputs.clone()
        actual_images, actual_targets = no_augment(images=inputs, targets=targets)
        assert_expected(actual_images, expected_images)
        # Without augmentation the targets are just the encoded labels. The
        # values look like label smoothing of 0.1 over 4 classes
        # (0.025 off-target, 0.925 on-target) -- confirm against MixUpCutMix.
        assert_expected(
            actual_targets,
            torch.tensor(
                [
                    [0.0250, 0.9250, 0.0250, 0.0250],
                    [0.0250, 0.0250, 0.9250, 0.0250],
                    [0.0250, 0.0250, 0.0250, 0.9250],
                    [0.9250, 0.0250, 0.0250, 0.0250],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )
| |
|
| |
|
@pytest.fixture
def wav():
    """Return the raw bytes of ``kaldi_file_8000.wav`` as a 1-D uint8 tensor.

    The file contents are copied into a ``bytearray`` so that
    ``torch.frombuffer`` receives a writable buffer. Passing immutable
    ``bytes`` makes PyTorch warn that the buffer is not writable, and the
    resulting tensor would alias read-only memory.
    """
    with open(get_asset_path("kaldi_file_8000.wav"), "rb") as f:
        data = bytearray(f.read())
    return torch.frombuffer(data, dtype=torch.uint8)
| |
|
| |
|
class TestAudioEvalTransform:
    """Shape and checksum tests for AudioEvalTransform on the kaldi test clip."""

    # Expected output shape and element sum for a single clip.
    _EXPECTED_SHAPE = (1, 1024, 128)
    _EXPECTED_SUM = 52000.8828

    @pytest.fixture
    def transform(self):
        return AudioEvalTransform()

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        features = transform(wav)
        assert_expected(features.size(), self._EXPECTED_SHAPE)
        assert_expected(
            features.sum().item(), self._EXPECTED_SUM, atol=0.0001, rtol=0.0
        )

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        # A list input comes back with an extra leading batch dimension.
        batched = transform([wav])
        assert_expected(batched.size(), (1, *self._EXPECTED_SHAPE))
        assert_expected(
            batched.sum().item(), self._EXPECTED_SUM, atol=0.0001, rtol=0.0
        )
| |
|
| |
|
class TestAudioPretrainTransform:
    """Seeded shape and checksum tests for AudioPretrainTransform."""

    # Expected output shape and element sum; the sum assumes the numpy seed
    # applied by ``set_seed``.
    _EXPECTED_SHAPE = (1, 1024, 128)
    _EXPECTED_SUM = 52072.4531

    @pytest.fixture(autouse=True)
    def set_seed(self):
        np.random.seed(0)

    @pytest.fixture
    def transform(self):
        return AudioPretrainTransform()

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        features = transform(wav)
        assert_expected(features.size(), self._EXPECTED_SHAPE)
        assert_expected(
            features.sum().item(), self._EXPECTED_SUM, atol=0.0001, rtol=0.0001
        )

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        # A list input comes back with an extra leading batch dimension.
        batched = transform([wav])
        assert_expected(batched.size(), (1, *self._EXPECTED_SHAPE))
        assert_expected(
            batched.sum().item(), self._EXPECTED_SUM, atol=0.0001, rtol=0.0001
        )
| |
|
| |
|
class TestAudioFinetuneTransform:
    """Seeded shape and checksum tests for AudioFineTuneTransform, incl. mixup."""

    @pytest.fixture(autouse=True)
    def set_seed(self):
        # Seed via the project helper and numpy; the expected sums below
        # assume this RNG state.
        set_rng_seed(0)
        np.random.seed(0)

    @pytest.fixture
    def transform(self):
        return AudioFineTuneTransform()

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        actual = transform(wav)
        assert_expected(actual.size(), (1, 1024, 128))
        assert_expected(actual.sum().item(), 53656.75, atol=0.0001, rtol=0.0001)

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        # A list input comes back with an extra leading batch dimension.
        actual = transform([wav])
        assert_expected(actual.size(), (1, 1, 1024, 128))
        assert_expected(actual.sum().item(), 53656.75, atol=0.0001, rtol=0.0001)

    @skip_if_no_ffmpeg()
    def test_transform_with_mixup(self, transform, wav):
        # Copy into a bytearray so torch.frombuffer gets a writable buffer;
        # immutable ``bytes`` triggers PyTorch's non-writable-buffer warning.
        with open(get_asset_path("sinewave.wav"), "rb") as f:
            mixup_bytes = bytearray(f.read())
        # The mixup source is passed as a list of encoded clips.
        mixup_wav = [torch.frombuffer(mixup_bytes, dtype=torch.uint8)]
        actual = transform(wav, mixup_wav, mix_lambda=0.5)
        assert_expected(actual.sum().item(), 54631.046875, atol=0.0001, rtol=0.0001)
| |
|
| |
|
class TestRandAug:
    """Regression test for RandAug applying all 15 operations to a tiny image."""

    @pytest.fixture(autouse=True)
    def set_seed(self):
        # Seed torch, numpy and the stdlib RNG so the sampled ops are fixed.
        torch.manual_seed(0)
        np.random.seed(0)
        random.seed(0)

    @pytest.fixture
    def inputs(self):
        plane = torch.tensor(
            [
                [1.0, 2.0, 1.0, 2.0],
                [2.0, 1.0, 1.0, 2.0],
                [2.0, 2.0, 1.0, 0.0],
                [7.0, 5.0, 4.0, 9.0],
            ]
        )
        # Three identical channels -> a contiguous (3, 4, 4) tensor.
        img = plane.repeat(3, 1, 1)
        # NOTE(review): ToPILImage rescales float tensors by 255 before casting
        # to 8-bit, so the values > 1.0 above fall outside the uint8 range. The
        # expected output in test_all_augment is pinned to whatever that
        # conversion produces -- confirm this input is intentional.
        return transforms.ToPILImage()(img)

    def test_all_augment(self, inputs):
        aug = RandAug(num_ops=15, prob=1, sample_with_replacement=False)
        actual_img = aug(inputs)
        # Use the project-wide comparison helper, like every other test here.
        assert_expected(
            transforms.ToTensor()(actual_img),
            torch.tensor(
                [
                    [
                        [0.7490, 0.0118, 0.3373, 0.8745],
                        [0.5333, 0.0667, 0.4667, 0.0275],
                        [0.5098, 0.0510, 0.5608, 0.6627],
                        [0.8039, 0.0000, 0.4863, 0.4863],
                    ],
                    [
                        [0.7843, 0.0118, 0.3294, 0.8784],
                        [0.5490, 0.0627, 0.4667, 0.0235],
                        [0.5294, 0.0471, 0.5647, 0.6941],
                        [0.8392, 0.0000, 0.4549, 0.4549],
                    ],
                    [
                        [0.2000, 0.0118, 0.3294, 0.7882],
                        [0.6667, 0.0667, 0.4667, 0.0980],
                        [0.4784, 0.0471, 0.5608, 0.1098],
                        [0.1020, 0.0000, 0.4078, 0.4078],
                    ],
                ]
            ),
            atol=0.0001,
            rtol=0.0,
        )
| |
|