|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
This file contains utilities to manipulate torch memory buffers
|
|
|
"""
|
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
class MemoryBuffer:
    """
    A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying
    memory. It must have a unique type to support this behavior.
    """

    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):
        # Logical element count, padded (allocated) element count, and the
        # single dtype every tensor placed in this buffer must share.
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        # Wrap an existing flat tensor when one is supplied; otherwise allocate
        # a fresh zero-filled 1-D CUDA tensor of the padded size.
        if source is None:
            self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device="cuda", requires_grad=False)
        else:
            self.data = source

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, "requested tensor is out of the buffer range."
        # Slice then reshape: the result aliases self.data (no copy).
        return self.data[start_index:end_index].view(shape)
|
|
|
|
|
|
|
|
|
def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
    """Return ``shape.numel()`` rounded up so buffers stay 128-bit aligned.

    Fix: the previous implementation called ``torch.finfo`` unconditionally,
    which raises ``TypeError`` for integer dtypes; dispatch on the dtype kind
    so integer parameters (e.g. ``torch.int64``) are handled too.

    Args:
        shape: shape of the tensor to be placed in the buffer.
        dtype: element type; any dtype covered by torch.finfo / torch.iinfo.

    Returns:
        The smallest multiple of the per-dtype alignment that is >= numel.
    """
    type_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    # number of elements that fit into 128 bits for this dtype
    align_numel = 128 // type_info.bits
    numel = shape.numel()
    # round up to the next multiple of align_numel
    return (numel + align_numel - 1) // align_numel * align_numel
|
|
|
|
|
|
|
|
|
def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]:
    """
    Return a dictionary containing name to a shape and dtype.
    """
    # Sort by parameter name so every caller sees the same deterministic
    # ordering (the buffer layout code relies on this ordering).
    return {
        name: {"shape": param.shape, "dtype": param.dtype}
        for name, param in sorted(module.named_parameters())
    }
|
|
|
|
|
|
|
|
|
def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
    """Build the memory buffer given weight_buffer_meta

    Args:
        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors

    Returns: a large memory buffer for each dtype that can hold all the tensors

    """
    # First pass: accumulate the padded element count each dtype requires.
    # Iterate in sorted-name order to match the reference-building code.
    total_numel_map: Dict[torch.dtype, int] = {}
    for _name, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info["shape"]
        dtype = meta_info["dtype"]

        assert isinstance(shape, torch.Size)
        assert isinstance(dtype, torch.dtype)

        total_numel_map[dtype] = total_numel_map.get(dtype, 0) + calc_padded_numel(shape, dtype)

    # Second pass: allocate one contiguous buffer per dtype.
    return {
        dtype: MemoryBuffer(total_numel, total_numel, dtype)
        for dtype, total_numel in total_numel_map.items()
    }
|
|
|
|
|
|
|
|
|
def build_memory_reference_from_module(module: torch.nn.Module, memory_buffers: Dict[torch.dtype, MemoryBuffer], maintain_weight=True):
    """Re-bind every parameter of `module` to a view into `memory_buffers`.

    Parameters are visited in sorted-name order so the resulting layout matches
    the one assumed by build_memory_buffer / build_memory_reference.

    Args:
        module: module whose parameters are re-pointed into the buffers.
        memory_buffers: mapping from dtype to the MemoryBuffer for that dtype.
        maintain_weight: when True, copy the current parameter values into the
            buffer before re-binding so the weights are preserved.

    Returns: None
    """
    # One running offset (in elements) per dtype.
    start_index = {dtype: 0 for dtype in memory_buffers}
    for name, param in sorted(module.named_parameters()):
        memory_buffer = memory_buffers[param.dtype]
        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])

        # Bug fix: advance the cursor by the padded size computed with *this*
        # parameter's dtype. The previous code passed the leaked loop variable
        # `dtype` (left over from initializing start_index), which produced
        # wrong offsets whenever the module mixed dtypes.
        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)
        if maintain_weight:
            buffer.copy_(param.data)
        param.data = buffer
|
|
|
|
|
|
|
|
|
def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
    """Build the memory references. The memory buffers are built using the build_memory_buffer API.
    This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.

    Args:
        weight_buffer_meta: mapping from tensor name to its shape/dtype meta.
        memory_buffers: per-dtype MemoryBuffer built by build_memory_buffer.

    Returns:
        dict mapping each tensor name to a view into its dtype's MemoryBuffer.
    """
    # One running offset per dtype, advanced in sorted-name order so the
    # layout is identical to the one sized by build_memory_buffer.
    offsets = {dtype: 0 for dtype in memory_buffers}
    weight_buffers = {}

    for name, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info["shape"]
        dtype = meta_info["dtype"]

        weight_buffers[name] = memory_buffers[dtype].get(shape, start_index=offsets[dtype])
        offsets[dtype] += calc_padded_numel(shape, dtype)

    return weight_buffers
|
|
|
|
|
|
|
|
|
class MemoryBufferModuleWrapper:
    """
    Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to
    - It will change the checkpoint name
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # Describe every parameter, size one contiguous buffer per dtype, then
        # re-point the wrapped module's parameters into those buffers (the
        # module and the buffers share the same underlying memory afterwards).
        self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)
        self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)
        build_memory_reference_from_module(self.module, self.memory_buffers)

    def get_memory_buffers(self):
        """Return the per-dtype MemoryBuffer dict backing the wrapped module."""
        return self.memory_buffers

    def get_weight_buffer_meta(self):
        """Return the name -> {shape, dtype} meta of the wrapped module."""
        return self.weight_buffer_meta
|
|
|
|
|
|
|
|
|
class MegatronMemoryBufferForRollout:
    """
    We assume that
    - inference engine has tp + dp
    - actor has tp + pp + dp
    - the tp between inference engine and actor should be the same
    - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer
    - weight_buffers: contains a list of weight_buffers, each is a dict from name to param
    - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that
      the named_parameters may not be directly compatible with inference engine. User has to take care of
      this part such as the layout mismatches. (e.g. qkv transpose)
    - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.
    - When doing weight sync, the data is transfer via memory buffers
    """

    def __init__(self, transform_memory_param_fn):
        # Callable that turns the per-pp-stage weight buffers into a flat
        # name -> parameter dict (handles pp/vpp name normalization).
        self.transform_memory_param_fn = transform_memory_param_fn
        self._memory_buffers = []
        self._weight_buffers = []
        self._named_parameters = {}

    def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]):
        """
        Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct
        a large buffer for each dtype in the weight_buffer.

        Args:
            weight_buffer_meta_pp: one weight-buffer meta dict per pp stage,
                each mapping tensor name to its shape/dtype meta.

        Returns: None

        """
        self.weight_buffer_meta_pp = weight_buffer_meta_pp

        # One set of per-dtype buffers per pp stage. References are built later
        # by build_memory_reference, so leave a placeholder for each stage.
        for stage_meta in self.weight_buffer_meta_pp:
            self._memory_buffers.append(build_memory_buffer(stage_meta))
            self._weight_buffers.append(None)

    def build_memory_reference(self):
        # Materialize the name -> buffer-view dict for every pp stage in place,
        # then let the transform function produce the normalized parameters.
        for stage, stage_meta in enumerate(self.weight_buffer_meta_pp):
            self._weight_buffers[stage] = build_memory_reference(stage_meta, self._memory_buffers[stage])
        self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)

    @property
    def named_parameters(self):
        """Flat dict of normalized parameter names to buffer-backed tensors."""
        return self._named_parameters

    @property
    def weight_buffers(self):
        """Per-pp-stage dicts mapping tensor name to its buffer view."""
        return self._weight_buffers

    @property
    def memory_buffers(self):
        """Per-pp-stage dicts mapping dtype to MemoryBuffer."""
        return self._memory_buffers
|
|
|
|