| |
| |
|
|
| |
|
|
| from typing import Any |
|
|
| import torch |
| import torch.distributed as dist |
| from torch import Tensor, nn |
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
| from torch.distributed import ReduceOp |
|
|
| from internlm.core.context import ParallelMode |
| from internlm.core.context.parallel_context import global_context as gpc |
|
|
|
|
| class NaiveAMPModel(nn.Module): |
| """ |
| This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16. |
| It also provides options to cast the output back to fp32 and to synchronize buffers. |
| |
| Args: |
| model (torch.nn.Module): The model to be wrapped and cast into fp16. |
| output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True. |
| parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module. |
| Defaults to ``ParallelMode.DATA``. |
| sync_buffer (bool, optional): If True, the buffers are synchronized. Defaults to True. |
| """ |
|
|
| def __init__( |
| self, |
| model: nn.Module, |
| output_to_fp32: bool = True, |
| parallel_mode: ParallelMode = ParallelMode.DATA, |
| sync_buffer: bool = True, |
| dtype=torch.float16, |
| ): |
| super().__init__() |
| self.model = model.to(dtype) |
| self._output_to_fp32 = output_to_fp32 |
| self._sync_buf = sync_buffer |
| self.dtype = dtype |
|
|
| if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: |
| self._process_group = gpc.get_group(parallel_mode) |
| self._world_size = gpc.get_world_size(parallel_mode) |
| else: |
| self._process_group = None |
| self._world_size = 1 |
| self._sync_buf = False |
| self._first_eval_run = False |
|
|
| @property |
| def sync_buffer(self): |
| """Returns the current state of the buffer synchronization.""" |
| return self._sync_buf |
|
|
| @sync_buffer.setter |
| def sync_buffer(self, state: bool): |
| """Sets the state of the buffer synchronization.""" |
| self._sync_buf = state |
|
|
| def _convert_to_fp16(self, input_: Any): |
| """Converts the input to fp16 if it is a Tensor of dtype float32.""" |
| if isinstance(input_, Tensor) and input_.dtype == torch.float32: |
| input_ = input_.to(self.dtype) |
| return input_ |
|
|
| def _convert_to_fp32(self, input_: Any): |
| """Converts the input to fp32 if it is a Tensor of dtype float16.""" |
| if isinstance(input_, Tensor) and input_.dtype == torch.float16: |
| input_ = input_.float() |
| return input_ |
|
|
| def convert_to_fp32(self, out): |
| """Converts the output to fp32""" |
| if isinstance(out, Tensor): |
| out = self._convert_to_fp32(out) |
| elif isinstance(out, (tuple, list)): |
| out = [self._convert_to_fp32(val) for val in out] |
| elif isinstance(out, dict): |
| out = {key: self._convert_to_fp32(val) for key, val in out.items()} |
|
|
| return out |
|
|
| def _reduce_module_buffer(self): |
| """ |
| All-reduces the buffers (e.g., running stats of batch normalization) across |
| data parallel ranks so that all the ranks will produce consistent results |
| when given the same input. |
| """ |
| buf_list = [] |
|
|
| |
| for buf in self.model.buffers(): |
| if buf is not None: |
| buf_list.append(buf) |
|
|
| |
| if buf_list: |
| coalesced_buf = _flatten_dense_tensors(buf_list) |
| coalesced_buf.div_(self._world_size) |
| dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group) |
| unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list) |
| for old, new in zip(buf_list, unflattened_buf_list): |
| old.copy_(new) |
|
|
| def eval(self): |
| """Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration.""" |
| self.model.eval() |
| self._first_eval_run = True |
|
|
| def forward(self, *args, **kwargs): |
| """ |
| Performs a forward pass on the model. Buffers are synchronized before the forward pass. |
| The inputs are converted to fp16 and the outputs are optionally converted back to fp32. |
| """ |
| if (self.training or self._first_eval_run) and self._sync_buf: |
| with torch.no_grad(): |
| self._reduce_module_buffer() |
|
|
| if self._first_eval_run: |
| self._first_eval_run = False |
|
|
| if args: |
| args = [self._convert_to_fp16(arg) for arg in args] |
| if kwargs: |
| for k, v in kwargs.items(): |
| kwargs[k] = self._convert_to_fp16(v) |
|
|
| out = self.model(*args, **kwargs) |
|
|
| if self._output_to_fp32: |
| out = self.convert_to_fp32(out) |
| return out |
|
|