# Respair's picture
# Upload folder using huggingface_hub
# b386992 verified
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import string
from contextlib import contextmanager
from pathlib import Path
from unittest import mock
import numpy as np
import pytest
import torch
from nemo.collections.common.parts.preprocessing.manifest import get_full_path, is_tarred_dataset
from nemo.collections.common.parts.utils import flatten, mask_sequence_tensor
class TestListUtils:
    """Unit tests for list utilities from nemo.collections.common.parts.utils."""

    @pytest.mark.unit
    def test_flatten(self):
        """Test flattening an iterable with different values: str, bool, int, float, complex."""
        # (nested input, expected flat output) pairs covering strings,
        # mixed nesting depths, and mixed scalar types.
        cases = [
            (['aa', 'bb', 'cc'], ['aa', 'bb', 'cc']),
            (['aa', ['bb', 'cc']], ['aa', 'bb', 'cc']),
            (['aa', [['bb'], [['cc']]]], ['aa', 'bb', 'cc']),
            (['aa', [[1, 2], [[3]], 4]], ['aa', 1, 2, 3, 4]),
            ([True, [2.5, 2.0 + 1j]], [True, 2.5, 2.0 + 1j]),
        ]
        for n, (nested, golden) in enumerate(cases):
            assert flatten(nested) == golden, f'Test case {n} failed!'
class TestMaskSequenceTensor:
    """Unit tests for mask_sequence_tensor across supported and unsupported ranks."""

    @pytest.mark.unit
    @pytest.mark.parametrize('ndim', [2, 3, 4, 5])
    def test_mask_sequence_tensor(self, ndim: int):
        """Test masking a tensor based on the provided length."""
        num_examples = 20
        max_batch_size = 10
        max_max_len = 30

        for n in range(num_examples):
            batch_size = np.random.randint(low=1, high=max_batch_size)
            max_len = np.random.randint(low=1, high=max_max_len)

            # Shape is (batch, [random intermediate dims,] time).
            shape = [batch_size]
            if ndim > 2:
                shape.extend(int(dim) for dim in torch.randint(1, 30, (ndim - 2,)))
            shape.append(max_len)

            tensor = torch.randn(tuple(shape))
            lengths = torch.randint(low=1, high=max_len + 1, size=(batch_size,))

            if ndim > 4:
                # Currently, supporting only up to 4D tensors
                with pytest.raises(ValueError):
                    mask_sequence_tensor(tensor=tensor, lengths=lengths)
            else:
                masked_tensor = mask_sequence_tensor(tensor=tensor, lengths=lengths)
                for b, l in enumerate(lengths):
                    # Values up to each example's length are kept, the rest are zeroed.
                    assert torch.equal(masked_tensor[b, ..., :l], tensor[b, ..., :l]), f'Failed for example {n}'
                    assert torch.all(masked_tensor[b, ..., l:] == 0.0), f'Failed for example {n}'
class TestPreprocessingUtils:
    """Unit tests for manifest preprocessing helpers `get_full_path` and `is_tarred_dataset`."""

    @pytest.mark.unit
    def test_get_full_path_local(self, tmpdir):
        """Test with local paths"""
        # Create a few files
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_absolute_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        data_dir = tmpdir
        manifest_file = os.path.join(data_dir, 'manifest.json')

        # Context manager to create dummy files
        @contextmanager
        def create_files(paths):
            """Touch every path on entry, remove every path on exit."""
            # Create files
            for a_file in paths:
                Path(a_file).touch()
            yield
            # Remove files
            for a_file in paths:
                Path(a_file).unlink()

        # 1) Test with absolute paths and while files don't exist.
        # Note: it's still expected the path will be resolved correctly, since it will be
        # expanded using manifest_file.parent or data_dir and relative path.
        # - single file
        for n in range(num_files):
            assert (
                get_full_path(audio_files_absolute_path[n], manifest_file=manifest_file)
                == audio_files_absolute_path[n]
            )
            assert get_full_path(audio_files_absolute_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

        # - all files in a list
        assert get_full_path(audio_files_absolute_path, manifest_file=manifest_file) == audio_files_absolute_path
        assert get_full_path(audio_files_absolute_path, data_dir=data_dir) == audio_files_absolute_path

        # 2) Test with absolute paths and existing files.
        with create_files(audio_files_absolute_path):
            # - single file
            for n in range(num_files):
                assert (
                    get_full_path(audio_files_absolute_path[n], manifest_file=manifest_file)
                    == audio_files_absolute_path[n]
                )
                assert get_full_path(audio_files_absolute_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

            # - all files in a list
            assert get_full_path(audio_files_absolute_path, manifest_file=manifest_file) == audio_files_absolute_path
            assert get_full_path(audio_files_absolute_path, data_dir=data_dir) == audio_files_absolute_path

        # 3) Test with relative paths while files don't exist.
        # This is a situation we may have with a tarred dataset.
        # In this case, we expect to return the relative path.
        # - single file
        for n in range(num_files):
            assert (
                get_full_path(audio_files_relative_path[n], manifest_file=manifest_file)
                == audio_files_relative_path[n]
            )
            assert get_full_path(audio_files_relative_path[n], data_dir=data_dir) == audio_files_relative_path[n]

        # - all files in a list
        assert get_full_path(audio_files_relative_path, manifest_file=manifest_file) == audio_files_relative_path
        assert get_full_path(audio_files_relative_path, data_dir=data_dir) == audio_files_relative_path

        # 4) Test with relative paths and existing files.
        # In this case, we expect to return the absolute path.
        with create_files(audio_files_absolute_path):
            # - single file
            for n in range(num_files):
                assert (
                    get_full_path(audio_files_relative_path[n], manifest_file=manifest_file)
                    == audio_files_absolute_path[n]
                )
                assert get_full_path(audio_files_relative_path[n], data_dir=data_dir) == audio_files_absolute_path[n]

            # - all files in a list
            assert get_full_path(audio_files_relative_path, manifest_file=manifest_file) == audio_files_absolute_path
            assert get_full_path(audio_files_relative_path, data_dir=data_dir) == audio_files_absolute_path

        # 5) Test with relative paths and existing files, and the filepaths start with './'.
        # In this case, we expect to return the same relative path.
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): these './'-prefixed files are created relative to the current
        # working directory, not tmpdir — presumably pytest runs from a writable CWD.
        audio_files_relative_path_curr = [f'./file_{n}.test' for n in range(num_files)]

        with create_files(audio_files_relative_path_curr):
            # - single file
            for n in range(num_files):
                assert os.path.isfile(audio_files_relative_path_curr[n]) == True
                assert (
                    get_full_path(audio_files_relative_path_curr[n], manifest_file=manifest_file)
                    == audio_files_relative_path_curr[n]
                )
                assert (
                    get_full_path(audio_files_relative_path_curr[n], data_dir=curr_dir)
                    == audio_files_relative_path_curr[n]
                )
            # - all files in a list
            assert (
                get_full_path(audio_files_relative_path_curr, manifest_file=manifest_file)
                == audio_files_relative_path_curr
            )
            assert get_full_path(audio_files_relative_path_curr, data_dir=curr_dir) == audio_files_relative_path_curr

    @pytest.mark.unit
    def test_get_full_path_ais(self, tmpdir):
        """Test with paths on AIStore."""
        # Create a few files
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_cache_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')

        # Context manager to create dummy files
        @contextmanager
        def create_files(paths):
            """Touch every path on entry, remove every path on exit."""
            # Create files
            for a_file in paths:
                Path(a_file).touch()
            yield
            # Remove files
            for a_file in paths:
                Path(a_file).unlink()

        # Simulate caching in local tmpdir
        def datastore_path_to_cache_path_in_tmpdir(path):
            """Map an AIStore object path to its local cached file in tmpdir."""
            rel_path = os.path.relpath(path, start=os.path.dirname(ais_manifest_file))
            if rel_path in audio_files_relative_path:
                idx = audio_files_relative_path.index(rel_path)
                return audio_files_cache_path[idx]
            else:
                raise ValueError(f'Unexpected path {path}')

        # Patch the datastore accessor so no real AIStore connection is needed.
        with mock.patch(
            'nemo.collections.common.parts.preprocessing.manifest.get_datastore_object',
            datastore_path_to_cache_path_in_tmpdir,
        ):
            # Test with relative paths and existing cached files.
            # We expect to return the absolute path in the local cache.
            with create_files(audio_files_cache_path):
                # - single file
                for n in range(num_files):
                    assert (
                        get_full_path(audio_files_relative_path[n], manifest_file=ais_manifest_file)
                        == audio_files_cache_path[n]
                    )
                    assert (
                        get_full_path(audio_files_relative_path[n], data_dir=ais_data_dir) == audio_files_cache_path[n]
                    )
                # - all files in a list
                assert (
                    get_full_path(audio_files_relative_path, manifest_file=ais_manifest_file) == audio_files_cache_path
                )
                assert get_full_path(audio_files_relative_path, data_dir=ais_data_dir) == audio_files_cache_path

    @pytest.mark.unit
    def test_get_full_path_ais_no_cache(self):
        """Test with paths on AIStore."""
        # Create a few files
        num_files = 10
        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')

        audio_files_absolute_path = [os.path.join(ais_data_dir, rel_path) for rel_path in audio_files_relative_path]

        # Test with only relative paths.
        # We expect to return the absolute path in the AIStore when force_cache is set to False.
        # This is used in Lhotse Dataloaders.
        for n in range(num_files):
            assert (
                get_full_path(audio_files_relative_path[n], manifest_file=ais_manifest_file, force_cache=False)
                == audio_files_absolute_path[n]
            )
            assert (
                get_full_path(audio_files_relative_path[n], data_dir=ais_data_dir, force_cache=False)
                == audio_files_absolute_path[n]
            )
        # - all files in a list
        assert (
            get_full_path(audio_files_relative_path, manifest_file=ais_manifest_file, force_cache=False)
            == audio_files_absolute_path
        )
        assert (
            get_full_path(audio_files_relative_path, data_dir=ais_data_dir, force_cache=False)
            == audio_files_absolute_path
        )

    @pytest.mark.unit
    def test_get_full_path_audio_file_len_limit(self):
        """Test with audio_file_len_limit.
        Currently, get_full_path will always return the input path when the length
        is over audio_file_len_limit, independent of whether the file exists.
        """
        # Create a few files
        num_examples = 10
        # Random path characters, including the path separator so names look like paths.
        rand_chars = list(string.ascii_uppercase + string.ascii_lowercase + string.digits + os.sep)
        rand_name = lambda n: ''.join(np.random.choice(rand_chars, size=n))

        for audio_file_len_limit in [255, 300]:
            for n in range(num_examples):
                # Path length always exceeds the limit, so input is returned unchanged.
                path_length = np.random.randint(low=audio_file_len_limit, high=350)
                audio_file_path = str(Path(rand_name(path_length)))

                assert (
                    get_full_path(audio_file_path, audio_file_len_limit=audio_file_len_limit) == audio_file_path
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path} to be returned.'

                # '~'-prefixed paths are still expanded, even when over the limit.
                audio_file_path_with_user = os.path.join('~', audio_file_path)
                audio_file_path_with_user_expected = os.path.expanduser(audio_file_path_with_user)

                assert (
                    get_full_path(audio_file_path_with_user, audio_file_len_limit=audio_file_len_limit)
                    == audio_file_path_with_user_expected
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path_with_user_expected} to be returned.'

    @pytest.mark.unit
    def test_get_full_path_invalid_type(self):
        """Make sure exceptions are raised when audio_file is not a string or a list of strings."""
        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path(1)

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path(('a', 'b', 'c'))

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path({'a': 1, 'b': 2, 'c': 3})

        with pytest.raises(ValueError, match="Unexpected audio_file type"):
            get_full_path([1, 2, 3])

    @pytest.mark.unit
    def test_get_full_path_invalid_relative_path(self):
        """Make sure exceptions are raised when audio_file is a relative path and
        manifest is not provided or both manifest and data dir are provided simultaneously.
        """
        with pytest.raises(ValueError, match="Use either manifest_file or data_dir"):
            # Using a relative path without manifest_file or data_dir is not allowed
            get_full_path('relative/path')

        with pytest.raises(ValueError, match="Parameters manifest_file and data_dir cannot be used simultaneously."):
            # Using a relative path with both manifest_file and data_dir is not allowed
            get_full_path('relative/path', manifest_file='/manifest_dir/file.json', data_dir='/data/dir')

    @pytest.mark.unit
    def test_is_tarred_dataset(self):
        """Check tarred-dataset detection from the audio filename / manifest filename pair."""
        # 1) is tarred dataset
        assert is_tarred_dataset("_file_1.wav", "tarred_audio_manifest.json")
        assert is_tarred_dataset("_file_1.wav", "./sharded_manifests/manifest_1.json")

        # 2) is not tarred dataset
        assert not is_tarred_dataset("./file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("./file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("/data/file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("/data/file_1.wav", "./sharded_manifests/manifest_test.json")
        assert not is_tarred_dataset("_file_1.wav", "audio_manifest.json")
        assert not is_tarred_dataset("_file_1.wav", "./sharded_manifests/manifest_test.json")

        # 3) no manifest file, treated as non-tarred dataset
        assert not is_tarred_dataset("_file_1.wav", None)