File size: 3,080 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import global_context as gpc


class BaseGradientHandler(ABC):
    """Abstract base class for helpers that all-reduce gradients across parallel groups
    before optimization.

    The handler only stores the model and the optimizer it is given; the actual
    gradient handling is supplied by subclasses through :func:`handle_gradient`.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer):
        # Kept as non-public attributes for subclasses to read.
        self._model, self._optimizer = model, optimizer

    @abstractmethod
    def handle_gradient(self):
        """Accumulate gradients across different parallel groups.

        This method is abstract: a concrete subclass has to override it, either with
        custom logic or by relying on one of the pre-defined subclasses.
        """


class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class that all-reduces gradients inside sub pipeline parallel groups.

    An all-reduce collective communication is performed in :func:`handle_gradient`
    for every process group found on the parameters' ``pipeline_shared_module_pg``
    attribute. To make the communication more efficient, the gradients are
    bucketized by process group and by tensor type, and each bucket is reduced
    as a single flattened tensor.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """Run an all-reduce operation (sum) in sub pipeline parallel groups.

        Nothing is done unless ``gpc.pipeline_parallel_size`` is greater than 1.
        The reduced values are written back into the gradient tensors in place.
        """
        if gpc.pipeline_parallel_size <= 1:
            return

        # Layout: process group -> tensor type string -> parameters reduced together.
        grouped = defaultdict(lambda: defaultdict(list))

        # Pack the buckets.
        for p in self._model.parameters():
            shared_pg = getattr(p, "pipeline_shared_module_pg", None)
            if not p.requires_grad or shared_pg is None:
                continue
            # A gradient is available either as a non-null saved grad on ``colo_attr``
            # or as a regular ``.grad`` tensor.
            has_saved_grad = hasattr(p, "colo_attr") and not p.colo_attr.saved_grad.is_null()
            if has_saved_grad or p.grad is not None:
                grouped[shared_pg][p.data.type()].append(p)

        # For each bucket, all-reduce and copy the all-reduced grads back.
        for shared_pg, params_by_type in grouped.items():
            for params in params_by_type.values():
                grad_tensors = []
                for p in params:
                    if hasattr(p, "colo_attr"):
                        grad_tensors.append(p.colo_attr.grad_payload)
                    else:
                        grad_tensors.append(p.grad.data)

                # One flat buffer per bucket, so a single collective call covers all of its grads.
                flat = _flatten_dense_tensors(grad_tensors).to(torch.cuda.current_device())
                dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=shared_pg)

                reduced_views = _unflatten_dense_tensors(flat, grad_tensors)
                for target, reduced in zip(grad_tensors, reduced_views):
                    target.copy_(reduced)