File size: 16,157 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import string
from contextlib import contextmanager
from pathlib import Path
from unittest import mock

import numpy as np
import pytest
import torch

from nemo.collections.common.parts.preprocessing.manifest import get_full_path, is_tarred_dataset
from nemo.collections.common.parts.utils import flatten, mask_sequence_tensor


class TestListUtils:
    @pytest.mark.unit
    def test_flatten(self):
        """Check that `flatten` turns nested lists into a single flat list.

        The inputs mix element types (str, bool, int, float, complex) and use
        different nesting depths.
        """
        # Each entry pairs a (possibly nested) input with the expected flat output.
        cases = [
            (['aa', 'bb', 'cc'], ['aa', 'bb', 'cc']),
            (['aa', ['bb', 'cc']], ['aa', 'bb', 'cc']),
            (['aa', [['bb'], [['cc']]]], ['aa', 'bb', 'cc']),
            (['aa', [[1, 2], [[3]], 4]], ['aa', 1, 2, 3, 4]),
            ([True, [2.5, 2.0 + 1j]], [True, 2.5, 2.0 + 1j]),
        ]

        for n, (nested, golden) in enumerate(cases):
            assert flatten(nested) == golden, f'Test case {n} failed!'


class TestMaskSequenceTensor:
    @pytest.mark.unit
    @pytest.mark.parametrize('ndim', [2, 3, 4, 5])
    def test_mask_sequence_tensor(self, ndim: int):
        """Check `mask_sequence_tensor` on random tensors with `ndim` dimensions.

        For inputs with up to four dimensions, values before each example's length
        must be unchanged and values from the length onwards must be zero. For
        five-dimensional inputs a ValueError is expected.
        """
        num_examples = 20
        max_batch_size = 10
        max_max_len = 30

        for n in range(num_examples):
            # Note: the upper bound of np.random.randint is exclusive
            batch_size = np.random.randint(low=1, high=max_batch_size)
            max_len = np.random.randint(low=1, high=max_max_len)

            # Shape is (batch_size, <ndim - 2 random dims>, max_len)
            middle_dims = tuple(torch.randint(1, 30, (ndim - 2,)).tolist()) if ndim > 2 else ()
            tensor_shape = (batch_size,) + middle_dims + (max_len,)

            tensor = torch.randn(tensor_shape)
            lengths = torch.randint(low=1, high=max_len + 1, size=(batch_size,))

            if ndim > 4:
                # Currently, supporting only up to 4D tensors
                with pytest.raises(ValueError):
                    mask_sequence_tensor(tensor=tensor, lengths=lengths)
                continue

            masked_tensor = mask_sequence_tensor(tensor=tensor, lengths=lengths)

            for b, length in enumerate(lengths):
                kept = masked_tensor[b, ..., :length]
                masked = masked_tensor[b, ..., length:]
                assert torch.equal(kept, tensor[b, ..., :length]), f'Failed for example {n}'
                assert torch.all(masked == 0.0), f'Failed for example {n}'


class TestPreprocessingUtils:
    @pytest.mark.unit
    def test_get_full_path_local(self, tmpdir):
        """Test `get_full_path` with paths on the local file system.

        Absolute and relative audio paths are resolved against either a manifest
        file or a data directory, first one path at a time and then as a list,
        both while the files are missing and while they exist.

        Args:
            tmpdir: pytest temporary directory, used as the data directory and as
                the location of the manifest file.
        """
        num_files = 10

        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_absolute_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        data_dir = tmpdir
        manifest_file = os.path.join(data_dir, 'manifest.json')

        @contextmanager
        def create_files(paths):
            """Create empty files at `paths` and remove them again on exit."""
            created = []
            try:
                for a_file in paths:
                    Path(a_file).touch()
                    created.append(a_file)
                yield
            finally:
                # Runs even if an assertion fails inside the `with` body, so no dummy
                # files are left behind (case 5 creates them in the working directory).
                for a_file in created:
                    Path(a_file).unlink()

        def check_full_path(audio_files, expected, base_dir):
            """Assert `audio_files` resolve to `expected` via manifest_file and via `base_dir`."""
            # - single file
            for a_file, a_expected in zip(audio_files, expected):
                assert get_full_path(a_file, manifest_file=manifest_file) == a_expected
                assert get_full_path(a_file, data_dir=base_dir) == a_expected

            # - all files in a list
            assert get_full_path(audio_files, manifest_file=manifest_file) == expected
            assert get_full_path(audio_files, data_dir=base_dir) == expected

        # 1) Test with absolute paths and while files don't exist.
        # Note: it's still expected the path will be resolved correctly, since it will be
        # expanded using manifest_file.parent or data_dir and relative path.
        check_full_path(audio_files_absolute_path, audio_files_absolute_path, base_dir=data_dir)

        # 2) Test with absolute paths and existing files.
        with create_files(audio_files_absolute_path):
            check_full_path(audio_files_absolute_path, audio_files_absolute_path, base_dir=data_dir)

        # 3) Test with relative paths while files don't exist.
        # This is a situation we may have with a tarred dataset.
        # In this case, we expect to return the relative path.
        check_full_path(audio_files_relative_path, audio_files_relative_path, base_dir=data_dir)

        # 4) Test with relative paths and existing files.
        # In this case, we expect to return the absolute path.
        with create_files(audio_files_absolute_path):
            check_full_path(audio_files_relative_path, audio_files_absolute_path, base_dir=data_dir)

        # 5) Test with relative paths and existing files, and the filepaths start with './'.
        # In this case, we expect to return the same relative path.
        # Note: these files are created relative to the current working directory.
        curr_dir = os.path.dirname(os.path.abspath(__file__))
        audio_files_relative_path_curr = [f'./file_{n}.test' for n in range(num_files)]
        with create_files(audio_files_relative_path_curr):
            for a_file in audio_files_relative_path_curr:
                assert os.path.isfile(a_file)

            check_full_path(audio_files_relative_path_curr, audio_files_relative_path_curr, base_dir=curr_dir)

    @pytest.mark.unit
    def test_get_full_path_ais(self, tmpdir):
        """Test `get_full_path` with paths on AIStore and a simulated local cache.

        `get_datastore_object` is patched so that every expected AIStore path maps
        to a file in `tmpdir`, which plays the role of the local cache.

        Args:
            tmpdir: pytest temporary directory holding the simulated cached files.
        """
        num_files = 10

        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_cache_path = [os.path.join(tmpdir, a_file_rel) for a_file_rel in audio_files_relative_path]

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')

        @contextmanager
        def create_files(paths):
            """Create empty files at `paths` and remove them again on exit."""
            created = []
            try:
                for a_file in paths:
                    Path(a_file).touch()
                    created.append(a_file)
                yield
            finally:
                # Runs even if an assertion fails inside the `with` body
                for a_file in created:
                    Path(a_file).unlink()

        # Simulate caching in local tmpdir
        def datastore_path_to_cache_path_in_tmpdir(path):
            """Map a path under the AIStore manifest directory to its cache file in tmpdir."""
            rel_path = os.path.relpath(path, start=os.path.dirname(ais_manifest_file))

            if rel_path not in audio_files_relative_path:
                raise ValueError(f'Unexpected path {path}')

            return audio_files_cache_path[audio_files_relative_path.index(rel_path)]

        with mock.patch(
            'nemo.collections.common.parts.preprocessing.manifest.get_datastore_object',
            datastore_path_to_cache_path_in_tmpdir,
        ):
            # Test with relative paths and existing cached files.
            # We expect to return the absolute path in the local cache.
            with create_files(audio_files_cache_path):
                # - single file
                for a_file, a_cached in zip(audio_files_relative_path, audio_files_cache_path):
                    assert get_full_path(a_file, manifest_file=ais_manifest_file) == a_cached
                    assert get_full_path(a_file, data_dir=ais_data_dir) == a_cached

                # - all files in a list
                assert (
                    get_full_path(audio_files_relative_path, manifest_file=ais_manifest_file) == audio_files_cache_path
                )
                assert get_full_path(audio_files_relative_path, data_dir=ais_data_dir) == audio_files_cache_path

    @pytest.mark.unit
    def test_get_full_path_ais_no_cache(self):
        """Test `get_full_path` with paths on AIStore and caching disabled (`force_cache=False`)."""
        num_files = 10

        ais_data_dir = 'ais://test'
        ais_manifest_file = os.path.join(ais_data_dir, 'manifest.json')

        audio_files_relative_path = [f'file_{n}.test' for n in range(num_files)]
        audio_files_absolute_path = [os.path.join(ais_data_dir, rel_path) for rel_path in audio_files_relative_path]

        # Test with only relative paths.
        # We expect to return the absolute path in the AIStore when force_cache is set to False.
        # This is used in Lhotse Dataloaders.
        for rel_path, abs_path in zip(audio_files_relative_path, audio_files_absolute_path):
            from_manifest = get_full_path(rel_path, manifest_file=ais_manifest_file, force_cache=False)
            assert from_manifest == abs_path

            from_data_dir = get_full_path(rel_path, data_dir=ais_data_dir, force_cache=False)
            assert from_data_dir == abs_path

        # - all files in a list
        all_from_manifest = get_full_path(
            audio_files_relative_path, manifest_file=ais_manifest_file, force_cache=False
        )
        assert all_from_manifest == audio_files_absolute_path

        all_from_data_dir = get_full_path(audio_files_relative_path, data_dir=ais_data_dir, force_cache=False)
        assert all_from_data_dir == audio_files_absolute_path

    @pytest.mark.unit
    def test_get_full_path_audio_file_len_limit(self):
        """Test with audio_file_len_limit.

        Currently, get_full_path will always return the input path when the length
        is over audio_file_len_limit, independent of whether the file exists.
        Paths starting with '~' are expected to be returned with the user
        directory expanded.
        """
        num_examples = 10
        rand_chars = list(string.ascii_uppercase + string.ascii_lowercase + string.digits + os.sep)

        def rand_name(length):
            """Return a random string of `length` characters drawn from `rand_chars`."""
            return ''.join(np.random.choice(rand_chars, size=length))

        for audio_file_len_limit in [255, 300]:
            for _ in range(num_examples):
                path_length = np.random.randint(low=audio_file_len_limit, high=350)

                # Path() drops repeated and trailing separators, which can make the
                # random name shorter than the limit. Pad it back to `path_length` so
                # every example really exercises the over-the-limit case.
                # NOTE(review): a shorter relative path would presumably hit the
                # "Use either manifest_file or data_dir" error instead - confirm.
                audio_file_path = str(Path(rand_name(path_length))).ljust(path_length, 'a')

                assert (
                    get_full_path(audio_file_path, audio_file_len_limit=audio_file_len_limit) == audio_file_path
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path} to be returned.'

                audio_file_path_with_user = os.path.join('~', audio_file_path)
                audio_file_path_with_user_expected = os.path.expanduser(audio_file_path_with_user)
                assert (
                    get_full_path(audio_file_path_with_user, audio_file_len_limit=audio_file_len_limit)
                    == audio_file_path_with_user_expected
                ), f'Limit {audio_file_len_limit}: expected {audio_file_path_with_user_expected} to be returned.'

    @pytest.mark.unit
    def test_get_full_path_invalid_type(self):
        """Check that a ValueError is raised when audio_file is neither a string nor a list of strings."""
        invalid_inputs = [
            1,
            ('a', 'b', 'c'),
            {'a': 1, 'b': 2, 'c': 3},
            [1, 2, 3],
        ]

        for audio_file in invalid_inputs:
            with pytest.raises(ValueError, match="Unexpected audio_file type"):
                get_full_path(audio_file)

    @pytest.mark.unit
    def test_get_full_path_invalid_relative_path(self):
        """Check the errors raised for a relative path with an invalid manifest_file/data_dir combination.

        A ValueError is expected when neither manifest_file nor data_dir is given,
        and when both are given at the same time.
        """
        relative_path = 'relative/path'

        # Neither manifest_file nor data_dir is provided: not allowed for a relative path
        expect_missing_base = pytest.raises(ValueError, match="Use either manifest_file or data_dir")
        with expect_missing_base:
            get_full_path(relative_path)

        # Both manifest_file and data_dir are provided: not allowed either
        expect_both_bases = pytest.raises(
            ValueError, match="Parameters manifest_file and data_dir cannot be used simultaneously."
        )
        with expect_both_bases:
            get_full_path(relative_path, manifest_file='/manifest_dir/file.json', data_dir='/data/dir')

    @pytest.mark.unit
    def test_is_tarred_dataset(self):
        """Check `is_tarred_dataset` for several audio file / manifest file combinations."""
        # 1) is tarred dataset
        tarred_cases = [
            ("_file_1.wav", "tarred_audio_manifest.json"),
            ("_file_1.wav", "./sharded_manifests/manifest_1.json"),
        ]
        for audio_file, manifest_file in tarred_cases:
            assert is_tarred_dataset(audio_file, manifest_file)

        not_tarred_cases = [
            # 2) is not tarred dataset
            ("./file_1.wav", "audio_manifest.json"),
            ("./file_1.wav", "./sharded_manifests/manifest_test.json"),
            ("file_1.wav", "audio_manifest.json"),
            ("file_1.wav", "./sharded_manifests/manifest_test.json"),
            ("/data/file_1.wav", "audio_manifest.json"),
            ("/data/file_1.wav", "./sharded_manifests/manifest_test.json"),
            ("_file_1.wav", "audio_manifest.json"),
            ("_file_1.wav", "./sharded_manifests/manifest_test.json"),
            # 3) no manifest file, treated as non-tarred dataset
            ("_file_1.wav", None),
        ]
        for audio_file, manifest_file in not_tarred_cases:
            assert not is_tarred_dataset(audio_file, manifest_file)