|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import fnmatch |
|
|
import logging |
|
|
import os |
|
|
import tarfile |
|
|
|
|
|
from typing import IO, Union |
|
|
|
|
|
LOGGER = logging.getLogger("NeMo")

# zarr is an optional dependency: when it cannot be imported, ZarrPathStore falls back to
# deriving from ``object`` and refuses construction (see HAVE_ZARR).
try:
    from zarr.storage import BaseStore
except Exception as e:
    LOGGER.warning(f"Cannot import zarr, support for zarr-based checkpoints is not available. {type(e).__name__}: {e}")
    BaseStore = object
    HAVE_ZARR = False
else:
    HAVE_ZARR = True
|
|
|
|
|
|
|
|
class TarPath:
    """
    A class that represents a path inside a TAR archive and behaves like pathlib.Path.

    Expected use is to create a TarPath for the root of the archive first, and then derive
    paths to other files or directories inside the archive like so:

        with TarPath('/path/to/archive.tar') as archive:
            myfile = archive / 'filename.txt'
            if myfile.exists():
                with myfile.open('rb') as f:
                    data = f.read()
            ...

    Only read and enumeration operations are supported. Members stored with a leading
    './' (e.g. './filename.txt') are addressed without that prefix.

    A TarPath created from a file name owns the underlying ``tarfile.TarFile`` and closes it
    when garbage-collected (or on leaving a ``with`` block). Paths derived from it via ``/``,
    ``glob``, ``rglob`` or ``iterdir`` hold a reference to that owner, so the archive stays
    open for as long as any derived path is alive.
    """

    def __init__(self, tar: Union[str, os.PathLike, tarfile.TarFile, 'TarPath'], *parts):
        """
        Args:
            tar: an archive file name (opened for reading and closed by this object), an
                already-open ``tarfile.TarFile`` (not closed by this object), or another
                TarPath to derive a path from.
            *parts: path components inside the archive, joined with ``os.path.join``.

        Raises:
            ValueError: if ``tar`` is of an unsupported type.
        """
        self._needs_to_close = False
        # The TarPath that owns (and eventually closes) the archive. It is held only to keep
        # that object alive while this derived path is still in use.
        self._owner = None
        self._relpath = ''
        if isinstance(tar, TarPath):
            self._tar = tar._tar
            self._relpath = os.path.join(tar._relpath, *parts)
            self._owner = tar if tar._needs_to_close else tar._owner
        elif isinstance(tar, tarfile.TarFile):
            self._tar = tar
            if parts:
                self._relpath = os.path.join(*parts)
        elif isinstance(tar, (str, os.PathLike)):
            self._tar = tarfile.open(tar, 'r')
            # Set only after a successful open, so a failing open does not leave __del__
            # trying to close a ``_tar`` attribute that was never assigned.
            self._needs_to_close = True
            if parts:
                self._relpath = os.path.join(*parts)
        else:
            raise ValueError(f"Unexpected argument type for TarPath: {type(tar).__name__}")

    def __del__(self):
        # getattr guards against partially initialized instances.
        if getattr(self, '_needs_to_close', False):
            self._tar.close()

    def __truediv__(self, key) -> 'TarPath':
        # Derive from ``self`` (not the raw TarFile) so the owning TarPath stays referenced.
        return TarPath(self, key)

    def __str__(self) -> str:
        # TarFile.name is None for archives opened from an unnamed file object.
        return os.path.join(self._tar.name or '', self._relpath)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)!r})"

    @property
    def tarobject(self):
        """
        Returns the wrapped tar object.
        """
        return self._tar

    @property
    def relpath(self):
        """
        Returns the relative path of the path.
        """
        return self._relpath

    @property
    def name(self):
        """
        Returns the name of the path.
        """
        return os.path.split(self._relpath)[1]

    @property
    def suffix(self):
        """
        Returns the suffix of the path (including the leading dot), or '' if there is none.
        """
        name = self.name
        i = name.rfind('.')
        if 0 < i < len(name) - 1:
            return name[i:]
        else:
            return ''

    def __enter__(self):
        self._tar.__enter__()
        return self

    def __exit__(self, *args):
        # Delegates to TarFile.__exit__, which closes the wrapped archive.
        return self._tar.__exit__(*args)

    def _find_member(self):
        """
        Returns the TarInfo for this path, or None if the archive has no such member.

        The path is looked up as given and then with a './' prefix.
        """
        for candidate in (self._relpath, os.path.join('.', self._relpath)):
            try:
                return self._tar.getmember(candidate)
            except KeyError:
                continue
        return None

    def _iter_relative_names(self):
        """
        Yields the names of archive members located beneath this path, relative to it.
        """
        prefix = self._relpath + '/' if self._relpath else ''
        for member in self._tar.getmembers():
            # Remove the "./" prefix, if any
            name = member.name[2:] if member.name.startswith('./') else member.name
            if not name.startswith(prefix):
                continue
            yield name[len(prefix) :]

    def _has_descendants(self):
        """
        Returns True if any archive member lies beneath this path.
        """
        return any(True for _ in self._iter_relative_names())

    def exists(self):
        """
        Checks if the path exists.

        A path also exists as an implicit directory when the archive lacks an explicit
        entry for it but contains members beneath it.
        """
        return self._find_member() is not None or self._has_descendants()

    def is_file(self):
        """
        Checks if the path is a regular file.
        """
        member = self._find_member()
        return member is not None and member.isreg()

    def is_dir(self):
        """
        Checks if the path is a directory.

        Returns True for explicit directory entries. It also returns True for implicit
        directories: paths with no entry of their own but with members beneath them.
        """
        member = self._find_member()
        if member is not None:
            return member.isdir()
        return self._has_descendants()

    def open(self, mode: str) -> IO[bytes]:
        """
        Opens a file in the archive for reading.

        Both 'r' and 'rb' return a binary file object.

        Raises:
            NotImplementedError: for any other mode.
            FileNotFoundError: if the path is missing or is not an extractable file.
        """
        if mode != 'r' and mode != 'rb':
            raise NotImplementedError(f"Unsupported mode for TarPath.open: {mode!r}")

        member = self._find_member()
        file = None
        if member is not None:
            try:
                file = self._tar.extractfile(member)
            except KeyError:
                # e.g. a link whose target is missing from the archive
                file = None

        # extractfile returns None for members that are not files (e.g. directories).
        if file is None:
            raise FileNotFoundError(f"No such file in archive: {self}")

        return file

    def glob(self, pattern):
        """
        Returns an iterator over the files in the directory, matching the pattern.

        Matching uses ``fnmatch`` on the full path relative to this one. As a result, and
        unlike pathlib, ``*`` also matches across '/' separators.
        """
        for name in self._iter_relative_names():
            if fnmatch.fnmatch(name, pattern):
                yield TarPath(self, name)

    def rglob(self, pattern):
        """
        Returns an iterator over the files in the directory, including subdirectories.

        A member is yielded at most once, if ``pattern`` matches any trailing run of its
        path components.
        """
        for name in self._iter_relative_names():
            parts = name.split('/')
            if any(fnmatch.fnmatch('/'.join(parts[i:]), pattern) for i in range(len(parts))):
                yield TarPath(self, name)

    def iterdir(self):
        """
        Returns an iterator over the files in the directory.

        Equivalent to ``glob('*')``, so it yields all members beneath this path, including
        those in subdirectories.
        """
        return self.glob('*')
|
|
|
|
|
|
|
|
class ZarrPathStore(BaseStore):
    """
    An implementation of read-only Store for zarr library
    that works with pathlib.Path or TarPath objects.

    Keys are '/'-separated paths of files relative to the wrapped path, e.g. '.zarray'.
    """

    def __init__(self, tarpath: TarPath):
        """
        Args:
            tarpath: the directory (a TarPath or a pathlib.Path) whose files form the store.
        """
        assert HAVE_ZARR, "Package zarr>=2.18.2,<3.0.0 is required to use ZarrPathStore"
        self._path = tarpath
        # zarr 2.x BaseStore reports writability through ``_writeable`` (is_writeable()).
        # The misspelled ``_writable`` used previously is kept for backward compatibility.
        self._writeable = False
        self._writable = False
        self._erasable = False

    def __getitem__(self, key):
        """
        Returns the raw bytes of the file stored under ``key``.

        Raises:
            KeyError: if there is no file for ``key``, as the Mapping/store contract requires.
        """
        try:
            with (self._path / key).open('rb') as file:
                return file.read()
        except (FileNotFoundError, IsADirectoryError) as e:
            raise KeyError(key) from e

    def __contains__(self, key):
        return (self._path / key).is_file()

    def __iter__(self):
        return self.keys()

    def __len__(self):
        return sum(1 for _ in self.keys())

    def __setitem__(self, key, value):
        raise NotImplementedError("ZarrPathStore is read-only")

    def __delitem__(self, key):
        raise NotImplementedError("ZarrPathStore is read-only")

    def keys(self):
        """
        Returns an iterator over the keys in the store.

        Each key is the path of a file beneath the wrapped path, relative to it, and is
        accepted by ``__getitem__``.
        """
        for item in self._path.rglob('*'):
            if item.is_file():
                yield self._key_for(item)

    def _key_for(self, item):
        """
        Returns the store key for ``item``, a path yielded by ``self._path.rglob``.
        """
        if isinstance(item, TarPath):
            base = self._path.relpath
            return item.relpath[len(base) + 1 :] if base else item.relpath
        return item.relative_to(self._path).as_posix()
|
|
|