# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import string
from contextlib import contextmanager
from pathlib import Path
from unittest import mock

import numpy as np
import pytest
import torch

from nemo.collections.common.parts.preprocessing.manifest import get_full_path, is_tarred_dataset
from nemo.collections.common.parts.utils import flatten, mask_sequence_tensor


class TestListUtils:
    @pytest.mark.unit
    def test_flatten(self):
        """Check that `flatten` turns arbitrarily nested lists into a flat list.

        Elements of several scalar types (str, bool, int, float, complex) are mixed in
        to make sure leaves are passed through untouched, in their original order.
        """
        # Each entry is (nested input, expected flat output).
        cases = [
            (['aa', 'bb', 'cc'], ['aa', 'bb', 'cc']),
            (['aa', ['bb', 'cc']], ['aa', 'bb', 'cc']),
            (['aa', [['bb'], [['cc']]]], ['aa', 'bb', 'cc']),
            (['aa', [[1, 2], [[3]], 4]], ['aa', 1, 2, 3, 4]),
            ([True, [2.5, 2.0 + 1j]], [True, 2.5, 2.0 + 1j]),
        ]

        for case_idx, (nested, expected) in enumerate(cases):
            assert flatten(nested) == expected, f'Test case {case_idx} failed!'
class TestMaskSequenceTensor:
    @pytest.mark.unit
    @pytest.mark.parametrize('ndim', [2, 3, 4, 5])
    def test_mask_sequence_tensor(self, ndim: int):
        """Test masking a tensor based on the provided length.

        Random shapes, lengths and values are drawn from generators seeded with `ndim`,
        so any failure can be reproduced by re-running the same parametrization.
        """
        num_examples = 20
        max_batch_size = 10
        max_max_len = 30
        max_inner_dim = 30

        rng = np.random.default_rng(seed=ndim)
        torch_rng = torch.Generator().manual_seed(ndim)

        for n in range(num_examples):
            # `high` is exclusive for Generator.integers.
            batch_size = int(rng.integers(low=1, high=max_batch_size))
            max_len = int(rng.integers(low=1, high=max_max_len))
            # Shape is (batch_size, *inner_dims, max_len); inner_dims is empty when ndim == 2.
            # Dimensions are converted to plain ints so the shape tuple holds no array scalars.
            inner_dims = tuple(int(d) for d in rng.integers(low=1, high=max_inner_dim, size=ndim - 2))
            tensor_shape = (batch_size, *inner_dims, max_len)
            tensor = torch.randn(tensor_shape, generator=torch_rng)
            lengths = torch.as_tensor(rng.integers(low=1, high=max_len + 1, size=batch_size), dtype=torch.long)

            if ndim <= 4:
                masked_tensor = mask_sequence_tensor(tensor=tensor, lengths=lengths)
                for b, length in enumerate(lengths.tolist()):
                    # Values within the valid length (along the last dim) must be unchanged...
                    assert torch.equal(
                        masked_tensor[b, ..., :length], tensor[b, ..., :length]
                    ), f'Failed for example {n}'
                    # ...and everything past it must be zeroed.
                    assert torch.all(masked_tensor[b, ..., length:] == 0.0), f'Failed for example {n}'
            else:
                # Currently, supporting only up to 4D tensors
                with pytest.raises(ValueError):
                    mask_sequence_tensor(tensor=tensor, lengths=lengths)


@contextmanager
def _create_files(paths):
    """Create empty files at `paths` for the duration of the context, then delete them.

    Cleanup runs in a `finally` block, so the files are removed even when an assertion
    inside the context fails. Only files that were actually created are removed.
    """
    created = []
    try:
        for a_file in paths:
            Path(a_file).touch()
            created.append(a_file)
        yield
    finally:
        for a_file in created:
            Path(a_file).unlink()


@contextmanager
def _working_directory(path):
    """Temporarily change the process working directory to `path`, restoring it on exit."""
    previous_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous_dir)


class TestPreprocessingUtils:
    @pytest.mark.unit
    def test_get_full_path_local(self, tmpdir):
        """Test with local paths"""
        # File names used throughout the test; files are created only inside `_create_files`.
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_absolute_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        data_dir = tmpdir
        manifest_file = os.path.join(data_dir, 'manifest.json')

        # 1) Test with absolute paths while the files don't exist.
        # Absolute paths are expected to be returned unchanged.
        # - single file
        for n in range(num_files):
            assert (
                get_full_path(audio_files_absolute_path[n], manifest_file=manifest_file)
                == audio_files_absolute_path[n]
            )
            assert get_full_path(audio_files_absolute_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

        # - all files in a list
        assert get_full_path(audio_files_absolute_path, manifest_file=manifest_file) == audio_files_absolute_path
        assert get_full_path(audio_files_absolute_path, data_dir=data_dir) == audio_files_absolute_path

        # 2) Test with absolute paths and existing files.
        with _create_files(audio_files_absolute_path):
            # - single file
            for n in range(num_files):
                assert (
                    get_full_path(audio_files_absolute_path[n], manifest_file=manifest_file)
                    == audio_files_absolute_path[n]
                )
                assert get_full_path(audio_files_absolute_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

            # - all files in a list
            assert get_full_path(audio_files_absolute_path, manifest_file=manifest_file) == audio_files_absolute_path
            assert get_full_path(audio_files_absolute_path, data_dir=data_dir) == audio_files_absolute_path

        # 3) Test with relative paths while files don't exist.
        # This is a situation we may have with a tarred dataset.
        # In this case, we expect to return the relative path.
        # - single file
        for n in range(num_files):
            assert (
                get_full_path(audio_files_relative_path[n], manifest_file=manifest_file)
                == audio_files_relative_path[n]
            )
            assert get_full_path(audio_files_relative_path[n], data_dir=data_dir) == audio_files_relative_path[n]

        # - all files in a list
        assert get_full_path(audio_files_relative_path, manifest_file=manifest_file) == audio_files_relative_path
        assert get_full_path(audio_files_relative_path, data_dir=data_dir) == audio_files_relative_path

        # 4) Test with relative paths and existing files.
        # In this case, we expect to return the absolute path.
        with _create_files(audio_files_absolute_path):
            # - single file
            for n in range(num_files):
                assert (
                    get_full_path(audio_files_relative_path[n], manifest_file=manifest_file)
                    == audio_files_absolute_path[n]
                )
                assert get_full_path(audio_files_relative_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

            # - all files in a list
            assert get_full_path(audio_files_relative_path, manifest_file=manifest_file) == audio_files_absolute_path
            assert get_full_path(audio_files_relative_path, data_dir=data_dir) == audio_files_absolute_path

        # 5) Test with relative paths and existing files, and the filepaths start with './'.
        # In this case, we expect to return the same relative path.
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        audio_files_relative_path_curr = [f'./file_{n}.test' for n in range(num_files)]
        # The './' files are created relative to the working directory. Use a dedicated, empty
        # subdirectory of tmpdir so the real working directory is never polluted and the files
        # do not coincide with anything under data_dir.
        work_dir = os.path.join(tmpdir, 'cwd')
        os.makedirs(work_dir)
        with _working_directory(work_dir), _create_files(audio_files_relative_path_curr):
            # - single file
            for n in range(num_files):
                assert os.path.isfile(audio_files_relative_path_curr[n])
                assert (
                    get_full_path(audio_files_relative_path_curr[n], manifest_file=manifest_file)
                    == audio_files_relative_path_curr[n]
                )
                assert (
                    get_full_path(audio_files_relative_path_curr[n], data_dir=curr_dir)
                    == audio_files_relative_path_curr[n]
                )

            # - all files in a list
            assert (
                get_full_path(audio_files_relative_path_curr, manifest_file=manifest_file)
                == audio_files_relative_path_curr
            )
            assert get_full_path(audio_files_relative_path_curr, data_dir=curr_dir) == audio_files_relative_path_curr

    @pytest.mark.unit
    def test_get_full_path_ais(self, tmpdir):
        """Test with paths on AIStore, resolved to a (mocked) local cache in tmpdir."""
        # File names on the store and their simulated local-cache locations
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_cache_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')

        # Simulate caching in local tmpdir
        def datastore_path_to_cache_path_in_tmpdir(path):
            rel_path = os.path.relpath(path, start=os.path.dirname(ais_manifest_file))

            if rel_path in audio_files_relative_path:
                idx = audio_files_relative_path.index(rel_path)
                return audio_files_cache_path[idx]
            else:
                raise ValueError(f'Unexpected path {path}')

        with mock.patch(
            'nemo.collections.common.parts.preprocessing.manifest.get_datastore_object',
            datastore_path_to_cache_path_in_tmpdir,
        ):
            # Test with relative paths and existing cached files.
            # We expect to return the absolute path in the local cache.
            with _create_files(audio_files_cache_path):
                # - single file
                for n in range(num_files):
                    assert (
                        get_full_path(audio_files_relative_path[n], manifest_file=ais_manifest_file)
                        == audio_files_cache_path[n]
                    )
                    assert (
                        get_full_path(audio_files_relative_path[n], data_dir=ais_data_dir)
                        == audio_files_cache_path[n]
                    )

                # - all files in a list
                assert (
                    get_full_path(audio_files_relative_path, manifest_file=ais_manifest_file)
                    == audio_files_cache_path
                )
                assert get_full_path(audio_files_relative_path, data_dir=ais_data_dir) == audio_files_cache_path

    @pytest.mark.unit
    def test_get_full_path_ais_no_cache(self):
        """Test with paths on AIStore when local caching is disabled (force_cache=False)."""
        # File names only; nothing is created on disk
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')
        audio_files_absolute_path = [os.path.join(ais_data_dir, rel_path) for rel_path in audio_files_relative_path]

        # Test with only relative paths.
        # We expect to return the absolute path in the AIStore when force_cache is set to False.
        # This is used in Lhotse Dataloaders.
        for n in range(num_files):
            assert (
                get_full_path(audio_files_relative_path[n], manifest_file=ais_manifest_file, force_cache=False)
                == audio_files_absolute_path[n]
            )
            assert (
                get_full_path(audio_files_relative_path[n], data_dir=ais_data_dir, force_cache=False)
                == audio_files_absolute_path[n]
            )

        # - all files in a list
        assert (
            get_full_path(audio_files_relative_path, manifest_file=ais_manifest_file, force_cache=False)
            == audio_files_absolute_path
        )
        assert (
            get_full_path(audio_files_relative_path, data_dir=ais_data_dir, force_cache=False)
            == audio_files_absolute_path
        )

    @pytest.mark.unit
    def test_get_full_path_audio_file_len_limit(self):
        """Test with audio_file_len_limit.

        Currently, get_full_path will always return the input path when the length
        is over audio_file_len_limit, independent of whether the file exists.
        """
        num_examples = 10
        sep_prob = 0.05
        name_chars = list(string.ascii_uppercase + string.ascii_lowercase + string.digits)
        rng = np.random.default_rng(seed=0)

        def rand_relative_path(length: int) -> str:
            """Return a random relative path of exactly `length` characters.

            Separators are placed only at interior positions and never next to each other,
            so the path has no empty components and is never absolute. Path normalization
            therefore cannot shorten it below the length limit under test, and joining it
            with '~' really does exercise user expansion.
            """
            chars = list(rng.choice(name_chars, size=length))
            for i in range(1, length - 1):
                if chars[i - 1] != os.sep and rng.random() < sep_prob:
                    chars[i] = os.sep
            return ''.join(chars)

        for audio_file_len_limit in [255, 300]:
            for _ in range(num_examples):
                path_length = int(rng.integers(low=audio_file_len_limit, high=350))
                audio_file_path = rand_relative_path(path_length)

                assert (
                    get_full_path(audio_file_path, audio_file_len_limit=audio_file_len_limit) == audio_file_path
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path} to be returned.'

                audio_file_path_with_user = os.path.join('~', audio_file_path)
                audio_file_path_with_user_expected = os.path.expanduser(audio_file_path_with_user)

                assert (
                    get_full_path(audio_file_path_with_user, audio_file_len_limit=audio_file_len_limit)
                    == audio_file_path_with_user_expected
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path_with_user_expected} to be returned.'

    @pytest.mark.unit
    def test_get_full_path_invalid_type(self):
        """Make sure exceptions are raised when audio_file is not a string or a list of strings."""

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path(1)

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path(('a', 'b', 'c'))

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path({'a': 1, 'b': 2, 'c': 3})

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path([1, 2, 3])

    @pytest.mark.unit
    def test_get_full_path_invalid_relative_path(self):
        """Make sure exceptions are raised when audio_file is a relative path and
        neither manifest_file nor data_dir is provided, or both are provided simultaneously.
        """
        with pytest.raises(ValueError, match="Use either manifest_file or data_dir"):
            # Using a relative path without manifest_file or data_dir is not allowed
            get_full_path('relative/path')

        with pytest.raises(ValueError, match="Parameters manifest_file and data_dir cannot be used simultaneously."):
            # Using a relative path with both manifest_file and data_dir is not allowed
            get_full_path('relative/path', manifest_file='/manifest_dir/file.json', data_dir='/data/dir')

    @pytest.mark.unit
    def test_is_tarred_dataset(self):
        """Check `is_tarred_dataset` on combinations of audio file names and manifest paths."""
        # 1) is tarred dataset
        assert is_tarred_dataset("_file_1.wav", "tarred_audio_manifest.json")
        assert is_tarred_dataset("_file_1.wav", "./sharded_manifests/manifest_1.json")

        # 2) is not tarred dataset
        assert not is_tarred_dataset("./file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("./file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("/data/file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("/data/file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("_file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("_file_1.wav", "./sharded_manifests/manifest_test.json")

        # 3) no manifest file, treated as non-tarred dataset
        assert not is_tarred_dataset("_file_1.wav", None)