| |
|
|
| import abc |
| import io |
| from collections.abc import Sequence |
| from typing import cast, IO, Optional |
|
|
| |
| from typing_extensions import Buffer |
|
|
| from torch._utils import try_import |
|
|
|
|
| |
| |
|
|
| pyzstd_module_name = "pyzstd" |
| pyzstd = try_import(pyzstd_module_name) |
| zstandard_module_name = "zstandard" |
| zstandard = try_import(zstandard_module_name) |
|
|
|
|
| __all__ = [ |
| "Extension", |
| "StreamTransformExtension", |
| "ZStandard", |
| "ExtensionRegistry", |
| ] |
|
|
|
|
| class Extension(abc.ABC): |
| """ |
| Extensions provide modular additions to functionality within distributed checkpointing, |
| which affect the layout or format of the written artifacts. Extensions may be |
| built into pytorch, or provided externally. |
| |
| When writing, the caller provides a list of extension instances of the appropriate |
| type. Each extension can output a descriptor which is used to reconstitute the |
| extension at read-time. |
| """ |
|
|
| @staticmethod |
| @abc.abstractmethod |
| def registry_name() -> str: |
| """ |
| See ExtensionRegistry.from_descriptor_list |
| """ |
|
|
| @staticmethod |
| @abc.abstractmethod |
| def from_descriptor(version: str) -> "Extension": |
| """ |
| See ExtensionRegistry.from_descriptor_list |
| """ |
|
|
| @abc.abstractmethod |
| def get_descriptor(self) -> str: |
| """ |
| Return descriptor name to be included in metadata. The form should be |
| "extension_name[@local-domain][/version]". |
| """ |
|
|
|
|
| class StreamTransformExtension(Extension): |
| """ |
| An extension which performs transformation on a byte stream, such as compression |
| or encryption. |
| |
| Implementations should try to be memory friendly and performant. For example, don't |
| read the whole input, then transform it, and write it back. If at all possible, do it in |
| chunks. But, don't read/transform/write one byte at a time, either. |
| """ |
|
|
| @abc.abstractmethod |
| def transform_to(self, output: IO[bytes]) -> IO[bytes]: |
| """ |
| Takes a writeable output stream, and generates a new stream which implements the |
| output transform. Input data written to the returned stream will be transformed |
| and written to the `output` argument stream. |
| """ |
|
|
| @abc.abstractmethod |
| def transform_from(self, input: IO[bytes]) -> IO[bytes]: |
| """ |
| Takes a readable input stream, and generates a new stream which implements the |
| input transform. When the returned stream is read, data will be read from the |
| 'input' stream, transformed, and returned. |
| """ |
|
|
|
|
| class ZStandard(StreamTransformExtension): |
| @staticmethod |
| def is_available() -> bool: |
| return zstandard is not None or pyzstd is not None |
|
|
| @staticmethod |
| def from_descriptor(version: str) -> "ZStandard": |
| if version.partition(".")[0] != "1": |
| raise ValueError(f"Unknown extension {version=}") |
| if not ZStandard.is_available(): |
| raise ValueError( |
| f"Stream with ZStandard compression cannot be processed because " |
| f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" |
| ) |
| return ZStandard() |
|
|
| @staticmethod |
| def registry_name() -> str: |
| return "stream.zstd" |
|
|
| def __init__(self) -> None: |
| super().__init__() |
| if not ZStandard.is_available(): |
| raise ValueError( |
| f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" |
| ) |
|
|
| def get_descriptor(self) -> str: |
| return f"{self.registry_name()}/1" |
|
|
| def transform_to(self, output: IO[bytes]) -> IO[bytes]: |
| if zstandard is not None: |
| compressor = zstandard.ZstdCompressor() |
| return compressor.stream_writer(output) |
|
|
| class Writer(io.RawIOBase): |
| def __init__(self, output: IO[bytes]) -> None: |
| self.output = output |
| self.compressor = pyzstd.ZstdCompressor() |
|
|
| def writeable(self) -> bool: |
| return True |
|
|
| def write(self, b: Buffer) -> Optional[int]: |
| outdata = self.compressor.compress(b) |
| if outdata: |
| self.output.write(outdata) |
| return len(memoryview(b)) |
|
|
| def flush(self) -> None: |
| outdata = self.compressor.flush() |
| if outdata: |
| self.output.write(outdata) |
| self.output.flush() |
|
|
| return cast(IO[bytes], Writer(output)) |
|
|
| def transform_from(self, input: IO[bytes]) -> IO[bytes]: |
| if zstandard is not None: |
| decompressor = zstandard.ZstdDecompressor() |
| return decompressor.stream_reader(input) |
|
|
| class Reader(io.RawIOBase): |
| def __init__(self, input: IO[bytes]) -> None: |
| self.input = input |
| self.decompressor = pyzstd.EndlessZstdDecompressor() |
|
|
| def readable(self) -> bool: |
| return True |
|
|
| def readinto(self, b: Buffer) -> Optional[int]: |
| |
| |
| |
| |
| |
|
|
| if self.decompressor.needs_input: |
| indata = self.input.read((128 + 6) * 1024) |
| else: |
| indata = b"" |
|
|
| bview = memoryview(b) |
| blen = len(bview) |
| outdata = self.decompressor.decompress(indata, blen) |
| if outdata is None: |
| return None |
|
|
| count = len(outdata) |
| bview[:count] = outdata |
| return count |
|
|
| def seekable(self) -> bool: |
| return False |
|
|
| return cast(IO[bytes], Reader(input)) |
|
|
|
|
| class ExtensionRegistry: |
| def __init__(self) -> None: |
| |
| self.extensions: dict[str, type[Extension]] = { |
| cls.registry_name(): cls for cls in (ZStandard,) |
| } |
|
|
| def register(self, cls: type[Extension]) -> None: |
| self.extensions[cls.registry_name()] = cls |
|
|
| def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]: |
| """ |
| Given a seuquence of descriptor strings as returned by |
| Extension.get_descriptor at save time, creates a sequence of |
| Extension instances. The name[@local-domain] preceding the |
| version number is used to look up an implementation class in |
| the registry, and the version is passed to the class's |
| from_descriptor static method. If the registry contains no |
| match, this will throw ValueError. If the from_descriptor |
| method raises an exception, that will pass through to the |
| caller. |
| """ |
|
|
| def from_descriptor(desc: str) -> Extension: |
| name, _, version = desc.partition("/") |
| if version is None: |
| version = 0 |
| ext = self.extensions.get(name) |
| if not ext: |
| raise ValueError(f"Unknown extension {name=}") |
| return ext.from_descriptor(version) |
|
|
| return [from_descriptor(desc) for desc in descriptors] |
|
|