File size: 12,158 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import shutil
import subprocess
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, Tuple
from urllib.parse import urlparse

try:
    from nemo import __version__ as NEMO_VERSION
except ImportError:
    NEMO_VERSION = 'git'

from nemo import constants
from nemo.utils import logging
from nemo.utils.nemo_logging import LogMode

try:
    from lhotse.serialization import open_best as lhotse_open_best

    LHOTSE_AVAILABLE = True
except ImportError:
    LHOTSE_AVAILABLE = False


def resolve_cache_dir() -> pathlib.Path:
    """
    Utility method to resolve a cache directory for NeMo that can be overridden by an environment variable.

    Example:
        NEMO_CACHE_DIR="~/nemo_cache_dir/" python nemo_example_script.py

    Returns:
        A Path object for the cache directory. If the environment variable named by
        `constants.NEMO_ENV_CACHE_DIR` is set to a non-empty value, that value is used after
        expanding a leading `~` and resolving it to an absolute path. Otherwise an inbuilt default
        under the user's home directory is used, which adapts to the NeMo version string.
    """
    override_dir = os.environ.get(constants.NEMO_ENV_CACHE_DIR, "")
    if override_dir == "":
        return pathlib.Path.home() / f'.cache/torch/NeMo/NeMo_{NEMO_VERSION}'
    # A quoted `~` (as in the example above) is not expanded by the shell and
    # `resolve()` does not expand it either, so expand it explicitly.
    return pathlib.Path(override_dir).expanduser().resolve()


def is_datastore_path(path) -> bool:
    """Check if a path is from a data object store.

    A path is considered a datastore path when it parses as a URL that has both
    a scheme and a network location, e.g. `ais://bucket/path/to/object`.

    Args:
        path: candidate path, normally a string

    Returns:
        True if `path` has both a URL scheme and a netloc, False otherwise.
        Inputs that cannot be parsed as a URL also yield False.
    """
    try:
        result = urlparse(path)
    except (AttributeError, ValueError):
        # AttributeError: non-string input without `.decode` (e.g., pathlib.Path).
        # ValueError: malformed URL (e.g., unbalanced `[` in the netloc).
        return False
    return bool(result.scheme) and bool(result.netloc)


def is_tarred_path(path) -> bool:
    """Tell whether a path refers to a tar archive, judged by its `.tar` suffix."""
    tar_suffix = '.tar'
    return path.endswith(tar_suffix)


def is_datastore_cache_shared() -> bool:
    """Report whether the datastore cache is shared.

    The flag is read from the environment variable named by
    `constants.NEMO_ENV_DATA_STORE_CACHE_SHARED` and interpreted as an integer.

    Returns:
        True when the flag is 1 (also the default when the variable is unset),
        False when the flag is 0.

    Raises:
        ValueError: if the flag is an integer other than 0 or 1.
    """
    # The cache is assumed to be shared by default, e.g., as in resolve_cache_dir (~/.cache)
    shared_flag = int(os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_SHARED, 1))

    if shared_flag not in (0, 1):
        raise ValueError(f'Unexpected value of env {constants.NEMO_ENV_DATA_STORE_CACHE_SHARED}')
    return shared_flag == 1


def ais_cache_base() -> str:
    """Return path to local cache for AIS.

    The base directory is taken from the environment variable named by
    `constants.NEMO_ENV_DATA_STORE_CACHE_DIR` when it is set to a non-empty value
    (a leading `~` is expanded and the path is made absolute). Otherwise the
    directory returned by `resolve_cache_dir()` is used.

    Returns:
        Path of the `ais` subdirectory inside the selected cache directory.
    """
    override_dir = os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_DIR, "")
    if override_dir == "":
        cache_dir = resolve_cache_dir().as_posix()
    else:
        # `resolve()` alone does not expand `~`, so a value such as "~/cache"
        # would otherwise end up under the current working directory.
        cache_dir = pathlib.Path(override_dir).expanduser().resolve().as_posix()

    if cache_dir.endswith(NEMO_VERSION):
        # Prevent re-caching dataset after upgrading NeMo
        cache_dir = os.path.dirname(cache_dir)
    return os.path.join(cache_dir, 'ais')


def ais_endpoint() -> str:
    """Read the configured AIS endpoint from the `AIS_ENDPOINT` environment variable.

    Returns:
        Value of `AIS_ENDPOINT`, or None when the variable is not set.
    """
    return os.environ.get('AIS_ENDPOINT')


def bucket_and_object_from_uri(uri: str) -> Tuple[str, str]:
    """Split an object store URI into its bucket name and object path.

    Args:
        uri: Full path to an object on an object store

    Returns:
        Tuple of strings (bucket_name, object_path)

    Raises:
        ValueError: if `uri` is not recognized as a datastore path.
    """
    if is_datastore_path(uri):
        # Path components: index 0 is the scheme part, index 1 the bucket,
        # everything after that forms the object path.
        components = pathlib.PurePath(uri).parts
        return str(components[1]), str(pathlib.PurePath(*components[2:]))

    raise ValueError(f'Provided URI is not a valid store path: {uri}')


def ais_endpoint_to_dir(endpoint: str) -> str:
    """Turn an AIS endpoint into a relative directory name.
    The result is used to build the cache location.

    Args:
        endpoint: AIStore endpoint in format https://host:port

    Returns:
        Directory formed as `host/port`.

    Raises:
        ValueError: if the endpoint has no hostname or no (non-zero) port.
    """
    parsed = urlparse(endpoint)
    if parsed.hostname and parsed.port:
        return os.path.join(parsed.hostname, str(parsed.port))
    raise ValueError(f"Unexpected format for ais endpoint: {endpoint}")


@lru_cache(maxsize=1)
def ais_binary() -> str:
    """Locate the `ais` binary.

    The executable search path is consulted first; if nothing is found there,
    `/usr/local/bin/ais` is checked as a fallback. The result is cached.

    Returns:
        Path to the binary, or None if it was found in neither location.
    """
    located = shutil.which('ais')
    if located is not None:
        logging.debug('Found AIS binary at %s', located)
        return located

    # Double-check if it exists at the default path
    default_path = '/usr/local/bin/ais'
    if not os.path.isfile(default_path):
        logging.warning(
            f'AIS binary not found with `which ais` and at the default path {default_path}.', mode=LogMode.ONCE
        )
        return None

    logging.info('ais available at the default path: %s', default_path, mode=LogMode.ONCE)
    return default_path


def datastore_path_to_local_path(store_path: str) -> str:
    """Map a data store path to the corresponding path in the local cache.

    The local path is built as `<ais cache base>/<host>/<port>/<bucket>/<object>`.

    Args:
        store_path: a path to an object on an object store

    Returns:
        Path to the same object in local cache.

    Raises:
        ValueError: if `store_path` is not a datastore path.
        RuntimeError: if the AIS endpoint is not configured.
    """
    if not is_datastore_path(store_path):
        raise ValueError(f'Unexpected store path format: {store_path}')

    endpoint = ais_endpoint()
    if not endpoint:
        raise RuntimeError(f'AIS endpoint not set, cannot resolve {store_path}')

    endpoint_cache_root = os.path.join(ais_cache_base(), ais_endpoint_to_dir(endpoint))
    bucket_name, object_name = bucket_and_object_from_uri(store_path)
    return os.path.join(endpoint_cache_root, bucket_name, object_name)


def open_datastore_object_with_binary(path: str, num_retries: int = 5):
    """Open a datastore object and return a file-like object.

    The object is streamed from the standard output of `ais get <path> -`.

    Args:
        path: path to an object
        num_retries: number of retries if the get command fails with ais binary, as AIS Python SDK has its own retry mechanism

    Returns:
        File-like object that supports read(). If `path` is not a datastore path,
        nothing is opened and None is returned.

    Raises:
        RuntimeError: if the AIS endpoint is not set or the AIS binary cannot be found.
        ValueError: if no attempt produced any data on standard output.
    """
    if not is_datastore_path(path):
        # Non-datastore paths are not handled by this function.
        return None

    endpoint = ais_endpoint()
    if endpoint is None:
        raise RuntimeError(f'AIS endpoint not set, cannot resolve {path}')

    binary = ais_binary()

    if not binary:
        raise RuntimeError(
            f"AIS binary is not found, cannot resolve {path}. Please either install it or install Lhotse with `pip install lhotse`.\n"
            "Lhotse's native open_best supports AIS Python SDK, which is the recommended way to operate with the data from AIStore.\n"
            "See AIS binary installation instructions at https://github.com/NVIDIA/aistore?tab=readme-ov-file#install-from-release-binaries.\n"
        )

    # Pass the arguments as a list and avoid the shell, so that characters in
    # `path` (spaces, quotes, `;`, `$`, ...) are never interpreted by a shell.
    cmd = [binary, 'get', path, '-']

    # Kept outside the loop so that the final error message is well-defined
    # even when `num_retries` is zero or negative.
    error = ''

    for _ in range(num_retries):
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=False)  # bytes mode
        stream = proc.stdout
        # peek() blocks until some data is available or stdout reaches EOF;
        # an empty result is treated as a failed attempt.
        if stream.peek(1):
            return stream

        # Failed attempt: remember the error, then release the pipes and reap
        # the process before retrying so nothing is leaked.
        error = proc.stderr.read().decode("utf-8", errors="ignore").strip()
        stream.close()
        proc.stderr.close()
        proc.wait()

    raise ValueError(
        f"{path} couldn't be opened with AIS binary after {num_retries} attempts because of the following exception: {error}"
    )


def open_best(path: str, mode: str = "rb"):
    """Open a local or datastore path with the best available backend.

    When Lhotse is importable, the call is delegated to Lhotse's `open_best`.
    Otherwise datastore paths are opened through the AIS binary (in that case
    `mode` is not used) and all other paths with the built-in `open`.

    Args:
        path: local path or path to an object on an object store
        mode: file mode forwarded to Lhotse's `open_best` or the built-in `open`

    Returns:
        File-like object for `path`.
    """
    if LHOTSE_AVAILABLE:
        handle = lhotse_open_best(path, mode=mode)
    elif is_datastore_path(path):
        handle = open_datastore_object_with_binary(path)
    else:
        handle = open(path, mode=mode)
    return handle


def get_datastore_object(path: str, force: bool = False, num_retries: int = 5) -> str:
    """Download an object from a store path and return the local path.
    If the input `path` is a local path, then nothing will be done, and
    the original path will be returned.

    Args:
        path: path to an object
        force: force download, even if a local file exists
        num_retries: number of retries if the get command fails with ais binary, as AIS Python SDK has its own retry mechanism

    Returns:
        Local path of the object.
    """
    if not is_datastore_path(path):
        # Assume the file is local
        return path

    local_path = datastore_path_to_local_path(store_path=path)

    if os.path.isfile(local_path) and not force:
        # Already cached
        # Enhancement: if local file is present, check some tag and compare against remote
        return local_path

    # Either we don't have the file in cache or we force download it
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    # Open the remote object BEFORE touching the cache, so that a failure to open
    # does not leave an empty file behind that would later look like a cache hit.
    if LHOTSE_AVAILABLE:
        # `num_retries` does not apply here; it is only used with the AIS binary.
        source = open_best(path, mode="rb")
    else:
        source = open_datastore_object_with_binary(path, num_retries=num_retries)

    # Download into a temporary sibling file and move it into place only after
    # the copy has completed, so an interrupted download never becomes a cached file.
    tmp_path = f'{local_path}.{os.getpid()}.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            shutil.copyfileobj(source, f)
        os.replace(tmp_path, local_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    finally:
        close = getattr(source, 'close', None)
        if callable(close):
            close()

    return local_path


class DataStoreObject:
    """A simple class for handling objects in a data store.
    Currently, this class supports objects on AIStore.

    Args:
        store_path: path to a store object
        local_path: path to a local object, may be used to upload local object to store.
            Currently not supported: any value other than None raises NotImplementedError.
        get: get the object from a store
    """

    def __init__(self, store_path: str, local_path: str = None, get: bool = False):
        if local_path is not None:
            raise NotImplementedError('Specifying a local path is currently not supported.')

        self._store_path = store_path
        # Stays None until the object has been fetched with `get()`
        self._local_path = local_path

        if get:
            self.get()

    @property
    def store_path(self) -> str:
        """Return store path of the object."""
        return self._store_path

    @property
    def local_path(self) -> str:
        """Return local path of the object, or None if it has not been fetched yet."""
        return self._local_path

    def get(self, force: bool = False) -> str:
        """Get an object from the store to local cache and return the local path.

        Args:
            force: force download, even if a local file exists

        Returns:
            Path to a local object.
        """
        # `force` must be honored even after a previous get() has already set the
        # local path, otherwise a forced re-download would silently be skipped.
        if force or not self.local_path:
            self._local_path = get_datastore_object(self.store_path, force=force)
        return self.local_path

    def put(self, force: bool = False) -> str:
        """Push to remote and return the store path. Not implemented yet.

        Args:
            force: force upload, even if a remote object exists

        Returns:
            Path to a (remote) object on the object store.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def __str__(self):
        """Return a human-readable description of the object."""
        description = f'{type(self)}: store_path={self.store_path}, local_path={self.local_path}'
        return description


def datastore_object_get(store_object: DataStoreObject) -> bool:
    """Fetch a store object; a convenience wrapper for multiprocessing.imap.

    Args:
        store_object: An instance of DataStoreObject

    Returns:
        True if get() returned a path, i.e., anything other than None.
    """
    fetched_path = store_object.get()
    return fetched_path is not None


def wds_url_opener(
    data: Iterable[Dict[str, Any]],
    handler: Callable[[Exception], bool],
    **kw: Dict[str, Any],
):
    """
    Open the URL of every sample and yield the samples with an added `stream` entry.
    This is a workaround to use this module's `open_best` instead of webdataset's default url_opener.
    webdataset's default url_opener uses gopen, which does not support opening datastore paths.

    Args:
        data: Iterator over dict(url=...).
        handler: Exception handler. Called with the exception; a truthy result skips to the next
            sample, a falsy result stops the iteration.
        **kw: Accepted for compatibility with the default url_opener signature; not used here.

    Yields:
        The input sample dicts, each updated in place with `stream` set to the opened file-like object.
    """
    for entry in data:
        assert isinstance(entry, dict), entry
        assert "url" in entry
        target = entry["url"]
        try:
            entry.update(stream=open_best(target, mode="rb"))
            yield entry
        except Exception as err:
            # The handler decides whether to move on to the next sample or stop.
            if not handler(err):
                break