File size: 3,756 Bytes
ad5f26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from concurrent.futures import Future
from typing import Any, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)


# Empty on purpose: ``from <this module> import *`` exports nothing.
__all__: list[str] = list()


class _Checkpointer:
    """This base class specefies a high level API for saving and loading

    distributed `state_dict` 's. It provides an abstraction over the low-level APIs

    provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling

    :py:meth: `torch.distributed.state_dict_saver.save` and

    :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage

    readers and writers.



    .. warning::

        This feature is experimental and subject to removal/change.



    """

    def __init__(

        self,

        storage_writer: StorageWriter,

        storage_reader: StorageReader,

        *,

        process_group: Optional[dist.ProcessGroup] = None,

        coordinator_rank: int = 0,

        no_dist: bool = False,

        load_planner: Optional[LoadPlanner] = None,

        save_planner: Optional[SavePlanner] = None,

    ):
        """Initializes the Checkpointer instance.



        Args:

            storage_writer: Instance of StorageWrite use to perform writes.

            storage_reader: StorageReader used to load data from.

            process_group: ProcessGroup to be used for cross-rank synchronization.

            coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.

            no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)

            loader_planner: Instance of LoadPlanner to use when loading.

            save_planner: Instance of SavePlanner to use when saving.

        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(

        self,

        state_dict: STATE_DICT_TYPE,

    ) -> Metadata:
        """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(

        self,

        state_dict: STATE_DICT_TYPE,

    ) -> Future:
        """

        Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization.



        Returns:

            Future: A future holding the resultant Metadata object from `save`.

        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )

    def load(self, state_dict: dict[str, Any]) -> None:
        """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )