# vlm_clone_2/multimodal/tests/transforms/test_mae_transform.py
# tuandunghcmut's picture
# Add files using upload-large-folder tool
# a1b8c74 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import pytest
import torch
from PIL import Image
from tests.test_utils import (
assert_expected,
get_asset_path,
set_rng_seed,
skip_if_no_ffmpeg,
)
from torchmultimodal.transforms.mae_transform import (
AudioEvalTransform,
AudioFineTuneTransform,
AudioPretrainTransform,
ImageEvalTransform,
ImagePretrainTransform,
MixUpCutMix,
RandAug,
)
from torchvision import transforms
# Relative paths of the test assets (image, wav clip, and the wav clip used
# for audio mixup).
# NOTE(review): none of these constants is referenced in the visible part of
# this file -- the fixtures and tests resolve assets through get_asset_path()
# with the bare file name instead. Confirm they are still needed.
IMAGE_PATH = "tests/assets/test_image.jpg"
WAV_PATH = "tests/assets/kaldi_file_8000.wav"
MIXUP_WAV_PATH = "tests/assets/sinewave.wav"
@pytest.fixture
def image():
    """Yield the ``test_image.jpg`` asset as a PIL image.

    ``Image.open`` is lazy and keeps the underlying file open, so the image is
    opened inside a ``with`` block and handed to the test via ``yield``. The
    file handle is then released deterministically at fixture teardown instead
    of being left for the garbage collector.
    """
    with Image.open(get_asset_path("test_image.jpg")) as img:
        yield img
class TestImageEvalTransform:
    """Checks ``ImageEvalTransform`` on a single image and on a list of images."""

    @pytest.fixture
    def transform(self):
        """Build the transform under test with ``input_size=224``."""
        eval_transform = ImageEvalTransform(input_size=224)
        return eval_transform

    def test_transform(self, transform, image):
        """A single image gives a (3, 224, 224) output with the reference mean."""
        output = transform(image)
        output_shape = output.size()
        assert_expected(output_shape, (3, 224, 224))
        output_mean = output.mean().item()
        assert_expected(output_mean, -0.4967, atol=0.0001, rtol=0.0)

    def test_transform_list(self, transform, image):
        """A one-element list gives a (1, 3, 224, 224) output with the same mean."""
        batch = [image]
        output = transform(batch)
        output_shape = output.size()
        assert_expected(output_shape, (1, 3, 224, 224))
        output_mean = output.mean().item()
        assert_expected(output_mean, -0.4967, atol=0.0001, rtol=0.0)
class TestImagePretrainTransform:
    """Checks ``ImagePretrainTransform`` against reference values under a fixed seed."""

    @pytest.fixture
    def transform(self):
        """Build the transform under test with ``input_size=224``."""
        pretrain_transform = ImagePretrainTransform(input_size=224)
        return pretrain_transform

    @pytest.fixture(autouse=True)
    def set_seed(self):
        """Reseed via ``set_rng_seed(0)`` before every test in this class."""
        set_rng_seed(0)

    def test_transform(self, transform, image):
        """A single image gives a (3, 224, 224) output with the reference mean."""
        output = transform(image)
        output_shape = output.size()
        assert_expected(output_shape, (3, 224, 224))
        output_mean = output.mean().item()
        assert_expected(output_mean, -0.4625, atol=0.0001, rtol=0.0)

    def test_transform_list(self, transform, image):
        """A one-element list gives a (1, 3, 224, 224) output with the same mean."""
        batch = [image]
        output = transform(batch)
        output_shape = output.size()
        assert_expected(output_shape, (1, 3, 224, 224))
        output_mean = output.mean().item()
        assert_expected(output_mean, -0.4625, atol=0.0001, rtol=0.0)
class TestMixup:
    """Checks ``MixUpCutMix`` outputs against reference tensors.

    Every test runs on the same batch of four single-channel 4x4 images with
    class targets ``[1, 2, 3, 0]`` and ``classes=4``; torch and NumPy are
    reseeded to 0 before each test.
    """

    @pytest.fixture(autouse=True)
    def set_seed(self):
        """Reseed torch and NumPy so the reference values are reproducible."""
        torch.manual_seed(0)
        np.random.seed(0)

    @pytest.fixture
    def inputs(self):
        """Return the fixed input batch of shape (4, 1, 4, 4)."""
        first = [
            [1.0, 2.0, 1.0, 2.0],
            [2.0, 1.0, 1.0, 2.0],
            [2.0, 2.0, 1.0, 0.0],
            [7.0, 5.0, 4.0, 9.0],
        ]
        second = [
            [2.0, 5.0, 1.0, 2.0],
            [1.0, 5.0, 1.0, 2.0],
            [3.0, 1.0, 9.0, 2.0],
            [5.0, 2.0, 1.0, 2.0],
        ]
        third = [
            [3.0, 4.0, 1.0, 2.0],
            [6.0, 3.0, 1.0, 2.0],
            [6.0, 5.0, 4.0, 1.0],
            [7.0, 7.0, 9.0, 7.0],
        ]
        fourth = [
            [1.0, 1.0, 1.0, 2.0],
            [1.0, 1.0, 1.0, 2.0],
            [8.0, 0.0, 4.0, 6.0],
            [0.0, 0.0, 3.0, 2.0],
        ]
        # Each sample is wrapped in its own list to form the channel dimension.
        return torch.Tensor([[first], [second], [third], [fourth]])

    @pytest.fixture
    def targets(self):
        """Return the integer class targets for the four samples."""
        return torch.tensor([1, 2, 3, 0], dtype=torch.long)

    def test_mixup(self, inputs, targets):
        """Compare outputs for ``switch_prob=0`` and ``cutmix_alpha=0`` with references."""
        tolerance = {"atol": 0.0001, "rtol": 0.0}
        mixup = MixUpCutMix(
            augment_prob=1, mixup_alpha=1, switch_prob=0, cutmix_alpha=0, classes=4
        )
        mixed_images, mixed_targets = mixup(images=inputs, targets=targets)

        expected_images = torch.Tensor(
            [
                [
                    [
                        [1.0000, 1.5626, 1.0000, 2.0000],
                        [1.5626, 1.0000, 1.0000, 2.0000],
                        [4.6245, 1.1252, 2.3123, 2.6245],
                        [3.9381, 2.8129, 3.5626, 5.9381],
                    ]
                ],
                [
                    [
                        [2.4374, 4.5626, 1.0000, 2.0000],
                        [3.1871, 4.1252, 1.0000, 2.0000],
                        [4.3123, 2.7497, 6.8129, 1.5626],
                        [5.8748, 4.1871, 4.4993, 4.1871],
                    ]
                ],
                [
                    [
                        [2.5626, 4.4374, 1.0000, 2.0000],
                        [3.8129, 3.8748, 1.0000, 2.0000],
                        [4.6877, 3.2503, 6.1871, 1.4374],
                        [6.1252, 4.8129, 5.5007, 4.8129],
                    ]
                ],
                [
                    [
                        [1.0000, 1.4374, 1.0000, 2.0000],
                        [1.4374, 1.0000, 1.0000, 2.0000],
                        [5.3755, 0.8748, 2.6877, 3.3755],
                        [3.0619, 2.1871, 3.4374, 5.0619],
                    ]
                ],
            ]
        )
        assert_expected(mixed_images, expected_images, **tolerance)

        expected_targets = torch.Tensor(
            [
                [0.4187, 0.5313, 0.0250, 0.0250],
                [0.0250, 0.0250, 0.5313, 0.4187],
                [0.0250, 0.0250, 0.4187, 0.5313],
                [0.5313, 0.4187, 0.0250, 0.0250],
            ]
        )
        assert_expected(mixed_targets, expected_targets, **tolerance)

    def test_cutmix(self, inputs, targets):
        """Compare outputs for ``switch_prob=1`` and ``cutmix_alpha=1`` with references."""
        tolerance = {"atol": 0.0001, "rtol": 0.0}
        cutmix = MixUpCutMix(
            augment_prob=1, mixup_alpha=1, switch_prob=1, cutmix_alpha=1, classes=4
        )
        mixed_images, mixed_targets = cutmix(images=inputs, targets=targets)

        expected_images = torch.Tensor(
            [
                [
                    [
                        [1.0, 2.0, 1.0, 2.0],
                        [2.0, 1.0, 1.0, 2.0],
                        [8.0, 2.0, 1.0, 0.0],
                        [0.0, 5.0, 4.0, 9.0],
                    ]
                ],
                [
                    [
                        [2.0, 5.0, 1.0, 2.0],
                        [1.0, 5.0, 1.0, 2.0],
                        [6.0, 1.0, 9.0, 2.0],
                        [7.0, 2.0, 1.0, 2.0],
                    ]
                ],
                [
                    [
                        [3.0, 4.0, 1.0, 2.0],
                        [6.0, 3.0, 1.0, 2.0],
                        [3.0, 5.0, 4.0, 1.0],
                        [5.0, 7.0, 9.0, 7.0],
                    ]
                ],
                [
                    [
                        [1.0, 1.0, 1.0, 2.0],
                        [1.0, 1.0, 1.0, 2.0],
                        [2.0, 0.0, 4.0, 6.0],
                        [7.0, 0.0, 3.0, 2.0],
                    ]
                ],
            ]
        )
        assert_expected(mixed_images, expected_images, **tolerance)

        expected_targets = torch.Tensor(
            [
                [0.1375, 0.8125, 0.0250, 0.0250],
                [0.0250, 0.0250, 0.8125, 0.1375],
                [0.0250, 0.0250, 0.1375, 0.8125],
                [0.8125, 0.1375, 0.0250, 0.0250],
            ]
        )
        assert_expected(mixed_targets, expected_targets, **tolerance)

    def test_no_augment(self, inputs, targets):
        """With ``augment_prob=0`` the images must come back equal to the inputs.

        The expected targets carry 0.9250 at each sample's class index and
        0.0250 everywhere else.
        """
        cutmix = MixUpCutMix(
            augment_prob=0, mixup_alpha=1, switch_prob=1, cutmix_alpha=1, classes=4
        )
        out_images, out_targets = cutmix(images=inputs, targets=targets)
        assert_expected(out_images, inputs)

        expected_targets = torch.Tensor(
            [
                [0.0250, 0.9250, 0.0250, 0.0250],
                [0.0250, 0.0250, 0.9250, 0.0250],
                [0.0250, 0.0250, 0.0250, 0.9250],
                [0.9250, 0.0250, 0.0250, 0.0250],
            ]
        )
        assert_expected(out_targets, expected_targets, atol=0.0001, rtol=0.0)
@pytest.fixture
def wav():
    """Return the raw bytes of ``kaldi_file_8000.wav`` as a 1-D uint8 tensor.

    The file content is read into a ``bytearray`` rather than kept as
    ``bytes``: ``torch.frombuffer`` shares memory with its source buffer, and
    wrapping an immutable ``bytes`` object yields a tensor over read-only
    memory (PyTorch warns about non-writable buffers). The tensor values are
    the same either way.
    """
    with open(get_asset_path("kaldi_file_8000.wav"), "rb") as f:
        raw = bytearray(f.read())
    return torch.frombuffer(raw, dtype=torch.uint8)
class TestAudioEvalTransform:
    """Checks ``AudioEvalTransform`` on the wav fixture; skipped without ffmpeg."""

    # Reference sum of the transform output for the wav fixture.
    _EXPECTED_SUM = 52000.8828

    @pytest.fixture
    def transform(self):
        """Build the transform under test with its default arguments."""
        eval_transform = AudioEvalTransform()
        return eval_transform

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        """A single wav tensor gives a (1, 1024, 128) output with the reference sum."""
        output = transform(wav)
        assert_expected(output.size(), (1, 1024, 128))
        output_sum = output.sum().item()
        assert_expected(output_sum, self._EXPECTED_SUM, atol=0.0001, rtol=0.0)

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        """A one-element list gives a (1, 1, 1024, 128) output with the same sum."""
        batch = [wav]
        output = transform(batch)
        assert_expected(output.size(), (1, 1, 1024, 128))
        output_sum = output.sum().item()
        assert_expected(output_sum, self._EXPECTED_SUM, atol=0.0001, rtol=0.0)
class TestAudioPretrainTransform:
    """Checks ``AudioPretrainTransform`` on the wav fixture; skipped without ffmpeg."""

    # Reference sum of the transform output for the wav fixture with NumPy seed 0.
    _EXPECTED_SUM = 52072.4531

    @pytest.fixture(autouse=True)
    def set_seed(self):
        """Reseed NumPy's global RNG before every test in this class."""
        np.random.seed(0)

    @pytest.fixture
    def transform(self):
        """Build the transform under test with its default arguments."""
        pretrain_transform = AudioPretrainTransform()
        return pretrain_transform

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        """A single wav tensor gives a (1, 1024, 128) output with the reference sum."""
        output = transform(wav)
        assert_expected(output.size(), (1, 1024, 128))
        output_sum = output.sum().item()
        assert_expected(output_sum, self._EXPECTED_SUM, atol=0.0001, rtol=0.0001)

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        """A one-element list gives a (1, 1, 1024, 128) output with the same sum."""
        batch = [wav]
        output = transform(batch)
        assert_expected(output.size(), (1, 1, 1024, 128))
        output_sum = output.sum().item()
        assert_expected(output_sum, self._EXPECTED_SUM, atol=0.0001, rtol=0.0001)
class TestAudioFinetuneTransform:
    """Checks ``AudioFineTuneTransform`` with and without a mixup clip.

    All tests are skipped when ffmpeg is unavailable and run after reseeding
    through ``set_rng_seed(0)`` and ``np.random.seed(0)``.
    """

    @pytest.fixture(autouse=True)
    def set_seed(self):
        """Reseed the RNGs before every test so the reference sums are stable."""
        set_rng_seed(0)
        np.random.seed(0)

    @pytest.fixture
    def transform(self):
        """Build the transform under test with its default arguments."""
        return AudioFineTuneTransform()

    @skip_if_no_ffmpeg()
    def test_transform(self, transform, wav):
        """A single wav tensor gives a (1, 1024, 128) output with the reference sum."""
        actual = transform(wav)
        assert_expected(actual.size(), (1, 1024, 128))
        assert_expected(actual.sum().item(), 53656.75, atol=0.0001, rtol=0.0001)

    @skip_if_no_ffmpeg()
    def test_transform_list(self, transform, wav):
        """A one-element list gives a (1, 1, 1024, 128) output with the same sum."""
        actual = transform([wav])
        assert_expected(actual.size(), (1, 1, 1024, 128))
        assert_expected(actual.sum().item(), 53656.75, atol=0.0001, rtol=0.0001)

    @skip_if_no_ffmpeg()
    def test_transform_with_mixup(self, transform, wav):
        """Mixing in ``sinewave.wav`` with ``mix_lambda=0.5`` gives the reference sum."""
        with open(get_asset_path("sinewave.wav"), "rb") as f:
            # Use a writable bytearray: torch.frombuffer shares memory with
            # its source, and an immutable bytes object would give a tensor
            # over read-only memory (PyTorch warns about this).
            bfr = bytearray(f.read())
        mixup_wav = [torch.frombuffer(bfr, dtype=torch.uint8)]
        actual = transform(wav, mixup_wav, mix_lambda=0.5)
        assert_expected(actual.sum().item(), 54631.046875, atol=0.0001, rtol=0.0001)
class TestRandAug:
    """Checks ``RandAug`` against a reference image under fixed seeds."""

    @pytest.fixture(autouse=True)
    def set_seed(self):
        """Reseed torch, NumPy and the stdlib ``random`` module before each test."""
        torch.manual_seed(0)
        np.random.seed(0)
        random.seed(0)

    @pytest.fixture
    def inputs(self):
        """Return a PIL image built from three identical 4x4 channels."""
        channel = [
            [1.0, 2.0, 1.0, 2.0],
            [2.0, 1.0, 1.0, 2.0],
            [2.0, 2.0, 1.0, 0.0],
            [7.0, 5.0, 4.0, 9.0],
        ]
        # The same plane is used for all three channels.
        img = torch.Tensor([channel, channel, channel])
        to_pil = transforms.ToPILImage()
        return to_pil(img)

    def test_all_augment(self, inputs):
        """Compare the output for ``num_ops=15`` without replacement to the reference."""
        aug = RandAug(num_ops=15, prob=1, sample_with_replacement=False)
        augmented = aug(inputs)
        augmented_tensor = transforms.ToTensor()(augmented)
        expected = torch.Tensor(
            [
                [
                    [0.7490, 0.0118, 0.3373, 0.8745],
                    [0.5333, 0.0667, 0.4667, 0.0275],
                    [0.5098, 0.0510, 0.5608, 0.6627],
                    [0.8039, 0.0000, 0.4863, 0.4863],
                ],
                [
                    [0.7843, 0.0118, 0.3294, 0.8784],
                    [0.5490, 0.0627, 0.4667, 0.0235],
                    [0.5294, 0.0471, 0.5647, 0.6941],
                    [0.8392, 0.0000, 0.4549, 0.4549],
                ],
                [
                    [0.2000, 0.0118, 0.3294, 0.7882],
                    [0.6667, 0.0667, 0.4667, 0.0980],
                    [0.4784, 0.0471, 0.5608, 0.1098],
                    [0.1020, 0.0000, 0.4078, 0.4078],
                ],
            ]
        )
        torch.testing.assert_close(
            augmented_tensor,
            expected,
            atol=0.0001,
            rtol=0.0,
        )