File size: 4,577 Bytes
f4cade0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# mypy: allow-untyped-defs
import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.

    This optimizer runs the local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>   PostLocalSGDState,
        >>>   post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>    opt.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        # NOTE: the base-class ``__init__`` is not invoked here; all optimizer
        # state lives in the wrapped ``optim`` and is reached by delegation.
        self.optim = optim
        # Same list object as the wrapped optimizer's, so groups added through
        # ``add_param_group`` are visible here as well.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):  # type: ignore[override]
        """Return the per-parameter state of the wrapped optimizer."""
        return self.optim.state

    def __repr__(self):
        """Return the ``repr`` of the wrapped optimizer."""
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record model averager's step to the checkpoint
        to ensure reload does not cause unnecessary warm up again.

        Returns:
            The wrapped optimizer's state dict with an additional ``"step"``
            entry holding ``self.averager.step``.
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will raise a warning and initialize the model averager's step to 0.

        Args:
            state_dict: A state dict to restore from. It is forwarded as-is to
                the wrapped optimizer's ``load_state_dict``.
        """
        self.optim.load_state_dict(state_dict)
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            # ``stacklevel=2`` attributes the warning to the code that passed
            # in the incomplete checkpoint rather than to this library frame.
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0.",
                stacklevel=2,
            )
            self.averager.step = 0

    def step(self):  # type: ignore[override]
        r"""
        Performs a single optimization step (parameter update).

        The wrapped optimizer's ``step`` runs first; the averager's
        ``average_parameters`` is then called with this optimizer's
        ``param_groups``.
        """
        self.optim.step()
        self.averager.average_parameters(params=self.param_groups)

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        """Forward ``zero_grad`` (including ``set_to_none``) to the wrapped optimizer."""
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        """Forward ``add_param_group`` to the wrapped optimizer."""
        self.optim.add_param_group(param_group)