| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| This file contains utilities to manipulate torch memory buffers |
| """ |
|
|
| from typing import Dict, List |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class MemoryBuffer:
    """
    A contiguous 1-D torch tensor used as backing storage for multiple tensors
    that share the underlying memory. A buffer holds exactly one dtype; views
    into it are handed out via :meth:`get`.

    Args:
        numel: number of elements callers may address through ``get``.
        numel_padded: actual allocation size in elements (must be >= ``numel``),
            e.g. rounded up for alignment.
        dtype: the unique dtype shared by all tensors placed in this buffer.
        device: device to allocate on. Defaults to ``'cuda'`` to preserve the
            original behavior; pass ``'cpu'`` (or any torch device) otherwise.
    """

    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, device: str = 'cuda'):
        assert numel_padded >= numel, 'numel_padded must be at least numel'
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        # requires_grad=False: this is raw storage, not a learnable parameter.
        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=device, requires_grad=False)

    def zero(self):
        """Reset the buffer to zero (in place)."""
        self.data.zero_()

    def get(self, shape: torch.Size, start_index: int) -> torch.Tensor:
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`.

        Raises:
            AssertionError: if the requested region extends past ``numel``.
        """
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        # A view, not a copy: writes through the returned tensor hit self.data.
        return self.data[start_index:end_index].view(shape)
|
|
|
|
def calc_padded_numel(shape: torch.Size, dtype: torch.dtype) -> int:
    """Round ``shape.numel()`` up so the allocation is 128-bit aligned.

    CUDA memory accesses are most efficient on 128-bit boundaries, so each
    tensor placed in a MemoryBuffer gets its element count padded accordingly.

    Args:
        shape: shape of the tensor to be placed in the buffer.
        dtype: its dtype — floating, complex, or integer (bool is not supported).

    Returns:
        the smallest multiple of ``128 // bits_per_element`` that is
        ``>= shape.numel()``.
    """
    # torch.finfo only accepts floating/complex dtypes; integer dtypes need
    # torch.iinfo (the original unconditionally used finfo, which raised a
    # TypeError for any integer dtype).
    if dtype.is_floating_point or dtype.is_complex:
        bits = torch.finfo(dtype).bits
    else:
        bits = torch.iinfo(dtype).bits
    align_numel = 128 // bits
    numel = shape.numel()
    return (numel + align_numel - 1) // align_numel * align_numel
|
|
|
|
def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]:
    """
    Collect, for every parameter of ``module``, its shape and dtype.

    Returns:
        a dict mapping parameter name -> ``{'shape': ..., 'dtype': ...}``,
        inserted in sorted-name order.
    """
    return {
        name: {'shape': param.shape, 'dtype': param.dtype}
        for name, param in sorted(module.named_parameters())
    }
|
|
|
|
def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
    """Build the memory buffer given weight_buffer_meta

    Args:
        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors

    Returns: a large memory buffer for each dtype that can hold all the tensors

    """
    # First pass: sum the 128-bit padded element counts per dtype.
    total_numel_per_dtype: Dict[torch.dtype, int] = {}
    for _, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info['shape']
        dtype = meta_info['dtype']

        assert isinstance(shape, torch.Size)
        assert isinstance(dtype, torch.dtype)

        padded = calc_padded_numel(shape, dtype)
        total_numel_per_dtype[dtype] = total_numel_per_dtype.get(dtype, 0) + padded

    # Second pass: one buffer per dtype, sized to hold every padded tensor.
    return {
        dtype: MemoryBuffer(total, total, dtype)
        for dtype, total in total_numel_per_dtype.items()
    }
|
|
|
|
def build_memory_reference_from_module(module: torch.nn.Module,
                                       memory_buffers: Dict[torch.dtype, MemoryBuffer],
                                       maintain_weight=True):
    """Re-point every parameter of ``module`` into the pre-allocated buffers.

    Parameters are visited in sorted-name order; each is rebound to a view
    into the buffer of its dtype, with offsets advanced by the 128-bit padded
    size so the layout matches ``build_memory_reference``.

    Args:
        module: module whose parameters will be rebound to buffer views.
        memory_buffers: one MemoryBuffer per dtype (see ``build_memory_buffer``).
        maintain_weight: if True, copy the current parameter values into the
            buffer before rebinding; if False the parameters take on the
            buffer's current (zeroed) contents.
    """
    start_index = {dtype: 0 for dtype in memory_buffers}
    for _name, param in sorted(module.named_parameters()):
        memory_buffer = memory_buffers[param.dtype]
        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])
        # Bug fix: advance by the padding of this parameter's OWN dtype.
        # The original passed the stale `dtype` left over from a previous
        # loop, which miscomputed offsets whenever the module mixes dtypes.
        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)
        if maintain_weight:
            buffer.copy_(param.data)
        param.data = buffer
|
|
|
|
def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
    """Build the memory references. The memory buffers are built using the build_memory_buffer API.
    This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.

    Args:
        weight_buffer_meta: mapping name -> {'shape': ..., 'dtype': ...}
        memory_buffers: mapping dtype -> MemoryBuffer backing storage

    Returns:
        a dict mapping each name to a tensor view of the requested shape.

    """
    # Per-dtype running offsets; tensors are laid out in sorted-name order
    # with 128-bit padded strides, mirroring build_memory_buffer's sizing.
    offsets = dict.fromkeys(memory_buffers, 0)
    weight_buffers = {}

    for name in sorted(weight_buffer_meta):
        meta_info = weight_buffer_meta[name]
        shape = meta_info['shape']
        dtype = meta_info['dtype']

        weight_buffers[name] = memory_buffers[dtype].get(shape, start_index=offsets[dtype])
        offsets[dtype] += calc_padded_numel(shape, dtype)

    return weight_buffers
|
|
|
|
class MemoryBufferModuleWrapper:
    """
    Wraps an nn.Module so that all of its parameters live inside contiguous
    per-dtype memory buffers (existing weight values are preserved).

    Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to
    - It will change the checkpoint name
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # Build metadata, allocate the buffers, then rebind the module's
        # parameters to views into those buffers.
        meta = get_weight_buffer_meta_from_module(module)
        buffers = build_memory_buffer(meta)
        build_memory_reference_from_module(module, buffers)
        self.weight_buffer_meta = meta
        self.memory_buffers = buffers

    def get_memory_buffers(self):
        """Return the dtype -> MemoryBuffer dict backing the wrapped module."""
        return self.memory_buffers

    def get_weight_buffer_meta(self):
        """Return the name -> {'shape', 'dtype'} metadata of the module."""
        return self.weight_buffer_meta
|
|
|
|
class MegatronMemoryBufferForRollout(object):
    """
    Aggregates per-pipeline-stage memory buffers so the actor's weights can be
    exposed to a rollout/inference engine.

    We assume that
    - inference engine has tp + dp
    - actor has tp + pp + dp
    - the tp between inference engine and actor should be the same
    - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer
    - weight_buffers: contains a list of weight_buffers, each is a dict from name to param
    - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that
      the named_parameters may not be directly compatible with inference engine. User has to take care of
      this part such as the layout mismatches. (e.g. qkv transpose)
    - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.
    - When doing weight sync, the data is transfer via memory buffers
    """

    def __init__(self, transform_memory_param_fn):
        # One entry per pipeline stage, populated by initialize_weight_buffer
        # and build_memory_reference respectively.
        self._memory_buffers = []
        self._weight_buffers = []
        self._named_parameters = {}
        self.transform_memory_param_fn = transform_memory_param_fn

    def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]):
        """
        Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct
        a large buffer for each dtype in the weight_buffer.

        Args:
            weight_buffer_meta_pp: one weight_buffer_meta per pipeline stage,
                each mapping tensor name -> {'shape': ..., 'dtype': ...}

        Returns: None

        """
        self.weight_buffer_meta_pp = weight_buffer_meta_pp
        for per_stage_meta in weight_buffer_meta_pp:
            # Allocate storage now; the named views are created later by
            # build_memory_reference, hence the None placeholder.
            self._memory_buffers.append(build_memory_buffer(per_stage_meta))
            self._weight_buffers.append(None)

    def build_memory_reference(self):
        """Create per-stage named views into the buffers and derive the
        normalized named_parameters via transform_memory_param_fn."""
        for stage, per_stage_meta in enumerate(self.weight_buffer_meta_pp):
            self._weight_buffers[stage] = build_memory_reference(per_stage_meta, self._memory_buffers[stage])
        self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)

    @property
    def named_parameters(self):
        """Normalized name -> parameter mapping (shares buffer memory)."""
        return self._named_parameters

    @property
    def weight_buffers(self):
        """Per-stage list of name -> tensor-view dicts."""
        return self._weight_buffers

    @property
    def memory_buffers(self):
        """Per-stage list of dtype -> MemoryBuffer dicts."""
        return self._memory_buffers
|
|