File size: 7,924 Bytes
ad5f26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# Copyright (c) Meta Platforms, Inc. and affiliates

import abc
import io
from collections.abc import Sequence
from typing import cast, IO, Optional

# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer

from torch._utils import try_import


# NOTE: everything in this file is experimental, and subject to
# change.  Feedback and bug fixes are always welcome.

# Optional compression backends.  Each module name is handed to try_import and
# the result is kept in a module-level variable; the rest of this file treats
# a value of None as "backend not available" (see ZStandard.is_available).
pyzstd_module_name = "pyzstd"
pyzstd = try_import(pyzstd_module_name)
zstandard_module_name = "zstandard"
zstandard = try_import(zstandard_module_name)


# Names exported by this module.
__all__ = [
    "Extension",
    "StreamTransformExtension",
    "ZStandard",
    "ExtensionRegistry",
]


class Extension(abc.ABC):
    """
    Base interface for distributed-checkpointing extensions.

    An extension is a modular addition to distributed checkpointing that
    affects the layout or format of the written artifacts.  Extensions may be
    built into pytorch, or provided externally.

    When writing, the caller supplies a list of extension instances of the
    appropriate type.  Each instance can output a descriptor, and that
    descriptor is what is used to reconstitute the extension at read-time.
    """

    @staticmethod
    @abc.abstractmethod
    def registry_name() -> str:
        """
        Return the name this extension class is registered under.

        See ExtensionRegistry.from_descriptor_list.
        """

    @staticmethod
    @abc.abstractmethod
    def from_descriptor(version: str) -> "Extension":
        """
        Create an extension instance from the version part of a descriptor.

        See ExtensionRegistry.from_descriptor_list.
        """

    @abc.abstractmethod
    def get_descriptor(self) -> str:
        """
        Return the descriptor name to be included in metadata.

        The form should be "extension_name[@local-domain][/version]".
        """


class StreamTransformExtension(Extension):
    """
    An extension that transforms a byte stream, e.g. compression or
    encryption.

    Implementations should try to be memory friendly and performant.  For
    example, avoid reading the whole input, transforming it, and then writing
    it back; if at all possible, work in chunks.  At the same time, do not
    read/transform/write one byte at a time, either.
    """

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        """
        Wrap a writeable stream with the output transform.

        Takes a writeable output stream and returns a new stream.  Data
        written to the returned stream is transformed and then written to the
        `output` argument stream.
        """

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        """
        Wrap a readable stream with the input transform.

        Takes a readable input stream and returns a new stream.  Reading from
        the returned stream reads data from the `input` stream, transforms it,
        and returns the result.
        """


class ZStandard(StreamTransformExtension):
    """
    Stream transform that compresses/decompresses data with Zstandard.

    The ``zstandard`` module is used when it was imported successfully;
    otherwise the implementation falls back to ``pyzstd``.  At least one of
    the two must be available, see ``is_available``.
    """

    @staticmethod
    def is_available() -> bool:
        """Return True if at least one supported zstd module was imported."""
        return zstandard is not None or pyzstd is not None

    @staticmethod
    def from_descriptor(version: str) -> "ZStandard":
        """
        Create a ZStandard extension for the given descriptor version.

        Only major version "1" (the text before the first ".") is accepted.

        Raises:
            ValueError: if the major version is not "1", or if neither zstd
                module is available.
        """
        if version.partition(".")[0] != "1":
            raise ValueError(f"Unknown extension {version=}")
        if not ZStandard.is_available():
            raise ValueError(
                f"Stream with ZStandard compression cannot be processed because "
                f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
            )
        return ZStandard()

    @staticmethod
    def registry_name() -> str:
        """Return the registry name of this extension, "stream.zstd"."""
        return "stream.zstd"

    def __init__(self) -> None:
        """
        Raises:
            ValueError: if neither zstd module is available.
        """
        super().__init__()
        if not ZStandard.is_available():
            raise ValueError(
                f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
            )

    def get_descriptor(self) -> str:
        """Return the descriptor "<registry_name>/1"."""
        return f"{self.registry_name()}/1"

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        """
        Return a writeable stream that compresses what is written to it and
        forwards the compressed bytes to ``output``.
        """
        if zstandard is not None:
            compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
            return compressor.stream_writer(output)

        class Writer(io.RawIOBase):
            """Raw write-only stream backed by a pyzstd compressor."""

            def __init__(self, output: IO[bytes]) -> None:
                super().__init__()
                self.output = output
                self.compressor = pyzstd.ZstdCompressor()  # type: ignore[union-attr]

            def writable(self) -> bool:
                # This is the hook the io machinery actually queries; without
                # it the RawIOBase default (False) would be reported.
                return True

            # Historical misspelling, kept as an alias so that any existing
            # caller of ``writeable()`` keeps working.
            writeable = writable

            def write(self, b: Buffer) -> Optional[int]:
                outdata = self.compressor.compress(b)
                if outdata:
                    self.output.write(outdata)
                # Report bytes consumed, not the item count of the buffer.
                return memoryview(b).nbytes

            def flush(self) -> None:
                # Drain whatever the compressor has buffered, then flush the
                # underlying stream.
                outdata = self.compressor.flush()
                if outdata:
                    self.output.write(outdata)
                self.output.flush()

        return cast(IO[bytes], Writer(output))

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        """
        Return a readable stream that reads compressed bytes from ``input``
        and yields the decompressed data.
        """
        if zstandard is not None:
            decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
            return decompressor.stream_reader(input)

        class Reader(io.RawIOBase):
            """Raw read-only stream backed by a pyzstd decompressor."""

            # Enough to cover one maximum-size (128KB) block plus some
            # overhead, so a single read normally yields decompressed output.
            _READ_SIZE = (128 + 6) * 1024

            def __init__(self, input: IO[bytes]) -> None:
                super().__init__()
                self.input = input
                self.decompressor = pyzstd.EndlessZstdDecompressor()  # type: ignore[union-attr]

            def readable(self) -> bool:
                return True

            def readinto(self, b: Buffer) -> Optional[int]:
                bview = memoryview(b)
                blen = len(bview)
                if blen == 0:
                    # Nothing can be stored; returning early also keeps the
                    # loop below from spinning on a zero-length destination.
                    return 0

                # A return value of 0 means EOF to the io machinery, so keep
                # feeding the decompressor until it produces something or the
                # underlying stream is really exhausted (e.g. after a short
                # read or a frame that carries no payload).
                while True:
                    if self.decompressor.needs_input:
                        indata = self.input.read(self._READ_SIZE)
                        exhausted = not indata
                    else:
                        indata = b""
                        exhausted = False

                    outdata = self.decompressor.decompress(indata, blen)
                    if outdata is None:
                        return None
                    if outdata or exhausted:
                        break

                count = len(outdata)
                bview[:count] = outdata
                return count

            def seekable(self) -> bool:
                return False

        return cast(IO[bytes], Reader(input))


class ExtensionRegistry:
    """
    Maps extension registry names to the classes implementing them.

    A new registry starts out containing the built-in ZStandard extension;
    further classes can be added with ``register``.
    """

    def __init__(self) -> None:
        # Populate default registry contents
        self.extensions: dict[str, type[Extension]] = {
            cls.registry_name(): cls for cls in (ZStandard,)
        }

    def register(self, cls: type[Extension]) -> None:
        """
        Register ``cls`` under ``cls.registry_name()``, replacing any class
        previously registered under the same name.
        """
        self.extensions[cls.registry_name()] = cls

    def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
        """
        Given a sequence of descriptor strings as returned by
        Extension.get_descriptor at save time, creates a sequence of
        Extension instances.  The name[@local-domain] preceding the
        version number is used to look up an implementation class in
        the registry, and the version is passed to the class's
        from_descriptor static method.  A descriptor without a version
        part is treated as version "0".

        Raises:
            ValueError: if the registry contains no match for a name.
            Any exception raised by a from_descriptor method passes
            through to the caller.
        """

        def from_descriptor(desc: str) -> Extension:
            name, _, version = desc.partition("/")
            # str.partition yields "" (never None) when there is no "/", and
            # from_descriptor expects a string, so default to the string "0".
            if not version:
                version = "0"
            ext = self.extensions.get(name)
            if ext is None:
                raise ValueError(f"Unknown extension {name=}")
            return ext.from_descriptor(version)

        return [from_descriptor(desc) for desc in descriptors]