#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

from typing import List, Optional

import torch
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler

from internlm.core.gradient_handler import BaseGradientHandler
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_batch_size, move_to_device


class Engine:
    """
    The Engine class is responsible for managing the training and evaluation process of a neural network model.
    It handles the forward and backward passes, parameter updates, gradient handling, and mode switching between
    training and evaluation.

    Args:
        model (torch.nn.Module): The neural network model to be trained or evaluated.
        optimizer (BaseOptimizer): The optimizer used for updating the parameters of the model.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler for the optimizer.
                                                                        Default is None.
        beta2_scheduler (internlm.solver.beta2_scheduler.Beta2Scheduler, optional): The beta2 scheduler for the
                                                                                    optimizer. Default is None.
        criterion (torch.nn.modules.loss._Loss, optional): The loss function used for calculating the loss during
                                                           training. Default is None.
        gradient_handlers (List[BaseGradientHandler], optional): A list of gradient handlers used in the backward pass.
                                                                 Default is None.
        clip_grad_norm (float, optional): The norm value for gradient clipping. Default is 0.0.

    Examples:
        >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
        >>> model = ...
        >>> criterion = ...
        >>> optimizer = ...
        >>> train_dataloader = ...
        >>> engine, _, _, _ = internlm.initialize_engine(model, optimizer, criterion)
        >>> engine.train()
        >>> for inputs, labels in train_dataloader:
        >>>     # set gradients to zero
        >>>     engine.zero_grad()
        >>>     # run forward pass
        >>>     outputs = engine(inputs)
        >>>     # compute loss value and run backward pass
        >>>     loss = engine.criterion(outputs, labels)
        >>>     engine.backward(loss)
        >>>     # update parameters
        >>>     engine.step()
    """

    def __init__(
        self,
        model: Module,
        optimizer: BaseOptimizer,
        lr_scheduler: Optional[_LRScheduler] = None,
        beta2_scheduler: Optional[Beta2Scheduler] = None,
        criterion: Optional[_Loss] = None,
        gradient_handlers: Optional[List[BaseGradientHandler]] = None,
        clip_grad_norm: float = 0.0,
    ):
        self._model = model
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self._beta2_scheduler = beta2_scheduler
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm

        # state
        self.training = True  # default

        # build gradient handler
        self._gradient_handlers = gradient_handlers if gradient_handlers else []

    @property
    def model(self):
        """Returns the model attached to the engine."""
        return self._model

    @property
    def optimizer(self):
        """Returns the optimizer attached to the engine."""
        return self._optimizer

    @property
    def criterion(self):
        """Returns the criterion (loss function) attached to the engine."""
        return self._criterion

    def _all_reduce_gradients(self):
        """Handles all-reduce operations of gradients across different parallel groups."""
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def zero_grad(self):
        """Sets the gradient of all parameters in the model to zero."""
        self.optimizer.zero_grad()

    def step(self):
        """
        Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
        and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
        if they exist.

        Returns:
            success (bool): Whether the parameter update was successful.
            grad_norm (float): The norm of the gradient after clipping.
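
        Examples:
            >>> # typically called after ``engine.backward(loss)``; a sketch, not the full loop
            >>> success, grad_norm = engine.step()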
        """
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

        success, grad_norm = self.optimizer.step()

        if success and self._lr_scheduler is not None:
            self._lr_scheduler.step()

        if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()

        return success, grad_norm

    def train(self):
        """Sets the model to training mode."""
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self.training = False
        self._model.eval()

    def backward(self, loss: torch.Tensor):
        """
        Starts the backward propagation given the loss value computed by a loss function.

        Args:
            loss (torch.Tensor): The loss value computed by a loss function.
        """
        return self.optimizer.backward(loss)

    def backward_by_grad(self, tensor, grad):
        """
        Starts the backward propagation given the gradient of the output tensor.

        Args:
            tensor (torch.Tensor): The output tensor.
            grad (torch.Tensor): The gradient passed back to the output tensor.
        """
        return self.optimizer.backward_by_grad(tensor, grad)

    def __call__(self, *args, **kwargs):
        """
        Runs the forward step for the model.

        Returns:
            torch.Tensor: The output of the model.
        """
        return self.model(*args, **kwargs)

    def load_batch(self, data_iter, to_gpu=True):
        """
        Loads a batch from the data iterator and, when ``to_gpu`` is True, moves it to
        the same device as the model.

        Args:
            data_iter (Iterable): The data iterator from which to get a batch of data, obtained by calling
                                  iter(dataloader).
            to_gpu (bool, optional): Whether the data should be moved to the GPU. Default is True.

        Returns:
            Tuple (Any, int): A tuple of (batch_data, batch_size).
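
        Examples:
            >>> # a minimal usage sketch; ``train_dataloader`` is a hypothetical torch DataLoader
            >>> data_iter = iter(train_dataloader)
            >>> batch_data, batch_size = engine.load_batch(data_iter)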
        """
        if data_iter is None:
            raise RuntimeError("Dataloader is not defined.")
        try:
            batch_data = next(data_iter)
        except TypeError:
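            # data_iter is not an iterator (next() raised TypeError); treat it as the batch itself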
            batch_data = data_iter

        if to_gpu:
            batch_data = move_to_device(batch_data)
        batch_size = get_batch_size(batch_data)

        return batch_data, batch_size


class KDEngine(Engine):
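    """
    An Engine subclass for knowledge distillation. In addition to the student ``model``,
    it holds a ``teacher`` model and an optional ``kd_criterion`` (distillation loss).

    Args:
        model (torch.nn.Module): The student model to be trained.
        teacher (torch.nn.Module): The teacher model whose outputs supervise the student.
        kd_criterion (torch.nn.modules.loss._Loss, optional): The loss computed between student
                                                              and teacher outputs. Default is None.

    All remaining arguments are the same as for ``Engine``.

    Examples:
        >>> # a minimal sketch of one distillation step; ``alpha`` (the distillation loss
        >>> # weight) is a hypothetical hyperparameter, not an attribute of this class
        >>> engine.train()
        >>> engine.zero_grad()
        >>> outputs = engine(inputs)
        >>> with torch.no_grad():
        >>>     teacher_outputs = engine.teacher(inputs)
        >>> loss = engine.criterion(outputs, labels) + alpha * engine.kd_criterion(outputs, teacher_outputs)
        >>> engine.backward(loss)
        >>> success, grad_norm = engine.step()
    """
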
    def __init__(
        self,
        model: Module,
        teacher: Module,
        optimizer: BaseOptimizer,
        lr_scheduler: Optional[_LRScheduler] = None,
        beta2_scheduler: Optional[Beta2Scheduler] = None,
        criterion: Optional[_Loss] = None,
        kd_criterion: Optional[_Loss] = None,
        gradient_handlers: Optional[List[BaseGradientHandler]] = None,
        clip_grad_norm: float = 0.0,
    ):
        self._teacher = teacher
        self._kd_criterion = kd_criterion

        super().__init__(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            beta2_scheduler=beta2_scheduler,
            criterion=criterion,
            gradient_handlers=gradient_handlers,
            clip_grad_norm=clip_grad_norm,
        )

    @property
    def teacher(self):
        """Returns the model attached to the engine."""
        return self._teacher

    @property
    def kd_criterion(self):
        """Returns the model attached to the engine."""
        return self._kd_criterion