|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Utilities for DeepSpeed Ulysses Sequence Parallelism.
|
|
|
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509
|
|
|
Inspired by: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py
|
|
|
"""
|
|
|
|
|
|
from typing import Any, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
|
from torch import Tensor
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
|
# Module-wide process group used for Ulysses sequence parallelism.
# Stays None until set_ulysses_sequence_parallel_group() stores a group.
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
|
|
|
|
|
|
|
|
|
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
    """
    Store ``group`` as the module-wide Ulysses sequence parallel process group.

    Args:
        group: process group to remember; it overwrites any previously stored group.
    """
    # Rebind the module-level variable rather than creating a local one.
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the module-wide Ulysses sequence parallel process group.

    Returns:
        The currently stored group, or ``None`` when none has been stored.
    """
    # Reading a module-level name needs no ``global`` declaration.
    return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
    """
    Return the number of ranks in the Ulysses sequence parallel group.

    Args:
        group: process group to query; when ``None`` the module-wide group is used.

    Returns:
        ``dist.get_world_size(group)``, or 1 when no usable group is available.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        # No sequence parallel group: behave like a single-rank setup.
        return 1
    return dist.get_world_size(group)
|
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
    """
    Return this process's rank inside the Ulysses sequence parallel group.

    Args:
        group: process group to query; when ``None`` the module-wide group is used.

    Returns:
        ``dist.get_rank(group)``, or 0 when no usable group is available.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        # No sequence parallel group: the only rank is rank 0.
        return 0
    return dist.get_rank(group)
|
|
|
|
|
|
|
|
|
def gather_seq_scatter_heads(
    x: Tensor,
    seq_dim: int,
    head_dim: int,
    unpadded_dim_size: int = 0,
    group: ProcessGroup = None,
) -> Tensor:
    """
    All-to-all that gathers the sequence dimension and scatters the head dimension.

    e.g. seq_dim: 1, head_dim: 2
    [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]

    Args:
        x: local tensor to exchange.
        seq_dim: dimension gathered across the group.
        head_dim: dimension scattered across the group.
        unpadded_dim_size: size ``seq_dim`` should have after gathering; when it is non-zero
            and not divisible by the group size, the surplus trailing entries are cut off.
        group: process group; defaults to the module-wide sequence parallel group.

    Returns:
        ``x`` itself when no group is available, otherwise the exchanged tensor.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return x

    world_size = get_ulysses_sequence_parallel_world_size(group)
    gathered = SeqAllToAll.apply(group, x, head_dim, seq_dim)

    must_unpad = unpadded_dim_size and unpadded_dim_size % world_size != 0
    if not must_unpad:
        return gathered

    # Everything beyond the requested size along seq_dim is treated as padding.
    surplus = gathered.size(seq_dim) - unpadded_dim_size
    return _unpad_tensor(gathered, seq_dim, surplus)
|
|
|
|
|
|
|
|
|
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
    """
    All-to-all that gathers the head dimension and scatters the sequence dimension.

    e.g. seq_dim: 1, head_dim: 2
    [bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]

    Args:
        x: local tensor to exchange.
        head_dim: dimension gathered across the group.
        seq_dim: dimension scattered across the group; zero-padded at the end first when its
            size is not divisible by the group size.
        group: process group; defaults to the module-wide sequence parallel group.

    Returns:
        ``x`` itself when no group is available, otherwise the exchanged tensor.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return x

    seq_len = x.size(seq_dim)
    world_size = get_ulysses_sequence_parallel_world_size(group)
    remainder = seq_len % world_size
    if remainder != 0:
        # Round the sequence length up to the next multiple of the group size.
        x = _pad_tensor(x, seq_dim, world_size - remainder)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
|
|
|
|
|
|
|
|
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
    """
    Append ``padding_size`` zeros to ``x`` along ``dim``.

    Args:
        x: tensor to pad.
        dim: dimension to extend (negative indices are accepted).
        padding_size: number of zero entries appended along ``dim``.

    Returns:
        A new tensor whose size along ``dim`` grew by ``padding_size``.
    """
    pad_shape = [*x.shape]
    pad_shape[dim] = padding_size
    # new_zeros inherits dtype and device from x.
    zeros = x.new_zeros(pad_shape)
    return torch.cat((x, zeros), dim=dim)
|
|
|
|
|
|
|
|
|
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
    """
    Drop the last ``padding_size`` entries of ``x`` along ``dim``.

    Args:
        x: tensor to trim.
        dim: dimension to trim.
        padding_size: number of trailing entries to remove; 0 leaves ``x`` untouched.

    Returns:
        ``x`` itself when ``padding_size`` is 0, otherwise a slice of ``x`` without the
        trailing ``padding_size`` entries along ``dim``.
    """
    if padding_size == 0:
        # slice(0, -0) equals slice(0, 0) and would discard every element.
        return x
    index = [slice(None)] * x.dim()
    index[dim] = slice(0, -padding_size)
    # Multidimensional indexing must use a tuple; a list of slices is deprecated.
    return x[tuple(index)]
|
|
|
|
|
|
|
|
|
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
    """
    Return the shard of ``x`` along ``dim`` that belongs to the current rank.

    Args:
        x: full tensor to split.
        dim: dimension to split across the group.
        padding: when True and the size of ``dim`` is not divisible by the group size, ``x`` is
            zero-padded at the end of ``dim`` first. When False, any remainder entries at the
            end of ``dim`` are not included in any shard.
        group: process group; defaults to the module-wide sequence parallel group.

    Returns:
        A contiguous tensor holding this rank's slice.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group)
    # Read the rank from the same group as the world size; otherwise an explicitly
    # passed group would be sliced with the module-wide group's rank.
    sp_rank = get_ulysses_sequence_parallel_rank(group)
    dim_size = x.size(dim)

    if padding and dim_size % sp_world_size:
        padding_size = sp_world_size - (dim_size % sp_world_size)
        x = _pad_tensor(x, dim, padding_size)

    parts = x.size(dim) // sp_world_size
    slc = [slice(None)] * x.dim()
    slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
    # Multidimensional indexing must use a tuple; a list of slices is deprecated.
    return x[tuple(slc)].contiguous()
|
|
|
|
|
|
|
|
|
def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
    async_op: bool = False,
):
    """
    Exchange ``local_input`` across the group with an all-to-all.

    The input is split along ``scatter_dim`` into one chunk per rank, the chunks are
    exchanged, and the received chunks are concatenated along ``gather_dim``.

    Args:
        local_input: tensor held by this rank.
        scatter_dim: dimension along which the input is split.
        gather_dim: dimension along which received chunks are concatenated.
        group: process group; defaults to the module-wide sequence parallel group.
        async_op: forwarded to ``dist.all_to_all``.

    Returns:
        The concatenated contiguous tensor when ``async_op`` is False. When ``async_op`` is
        True, a zero-argument ``wait`` callable that waits on the collective and then returns
        that tensor.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    world_size = dist.get_world_size(group)

    send_chunks = [chunk.contiguous() for chunk in torch.tensor_split(local_input, world_size, scatter_dim)]
    # Every receive buffer is shaped like the first send chunk.
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world_size)]
    work = dist.all_to_all(recv_chunks, send_chunks, group=group, async_op=async_op)

    def _assemble():
        return torch.cat(recv_chunks, dim=gather_dim).contiguous()

    if not async_op:
        return _assemble()

    def wait():
        work.wait()
        return _assemble()

    return wait
|
|
|
|
|
|
|
|
|
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
    """
    All-gather ``local_tensor`` from every rank into one tensor stacked along dim 0.

    Args:
        local_tensor: tensor contributed by this rank.
        group: process group; defaults to the module-wide sequence parallel group.
        async_op: forwarded to ``dist.all_gather_into_tensor``. The collective is waited on
            before returning, so the returned tensor is always safe to read.

    Returns:
        A tensor shaped like ``local_tensor`` except that dim 0 is multiplied by the group size.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group=group)

    output_shape = list(local_tensor.shape)
    output_shape[0] = output_shape[0] * sp_world_size
    output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)

    work = dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)
    if async_op and work is not None:
        # The handle is not returned to the caller, so it must be waited on here;
        # otherwise ``output`` could be read before the gather has completed.
        work.wait()
    return output
|
|
|
|
|
|
|
|
|
class SeqAllToAll(torch.autograd.Function):
    """
    Autograd wrapper around ``all_to_all_tensor``.

    The backward pass runs the reverse exchange: the gradient is scattered along the
    forward ``gather_dim`` and gathered along the forward ``scatter_dim``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool = False,
    ) -> Tensor:
        """
        Run the all-to-all and record what backward needs on ``ctx``.

        Args:
            ctx: autograd context.
            group: process group to communicate over.
            local_input: tensor held by this rank.
            scatter_dim: dimension split across ranks.
            gather_dim: dimension concatenated from all ranks.
            async_op: forwarded to ``all_to_all_tensor``.

        Returns:
            Whatever ``all_to_all_tensor`` returns for the given ``async_op``.
        """
        # NOTE(review): with async_op=True the return value of all_to_all_tensor is passed
        # through unchanged; confirm with callers that this path is actually used.
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        """
        Exchange the incoming gradient back with the dimensions swapped.

        Returns:
            One entry per ``forward`` input (group, local_input, scatter_dim, gather_dim,
            async_op); only ``local_input`` receives a gradient.
        """
        if ctx.async_op:
            input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
        else:
            input_t = grad_output[0]
        # The backward exchange is always synchronous.
        grad_input = all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False)
        # Exactly five gradients: one per forward argument after ctx.
        return None, grad_input, None, None, None
|
|
|
|
|
|
|
|
|
class Gather(torch.autograd.Function):
    """
    Autograd-aware all-gather that concatenates every rank's tensor along ``gather_dim``.

    The backward pass hands each rank the slice of the gradient that corresponds to its
    own contribution, optionally multiplied by the group size.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_tensor: Tensor,
        gather_dim: int,
        grad_scaler: bool = True,
        async_op=False,
    ) -> Tensor:
        """
        Gather ``local_tensor`` from all ranks and concatenate along ``gather_dim``.

        Args:
            ctx: autograd context.
            group: process group to communicate over.
            local_tensor: tensor contributed by this rank.
            gather_dim: dimension along which the gathered pieces are concatenated.
            grad_scaler: when True, backward multiplies the gradient by the group size.
            async_op: forwarded to ``all_gather_tensor``.

        Returns:
            The concatenation of every rank's tensor along ``gather_dim``.
        """
        ctx.group = group
        ctx.gather_dim = gather_dim
        ctx.grad_scaler = grad_scaler
        ctx.async_op = async_op

        sp_world_size = dist.get_world_size(group=group)
        ctx.sp_world_size = sp_world_size

        sp_rank = dist.get_rank(group=group)
        ctx.sp_rank = sp_rank

        local_shape = list(local_tensor.size())
        split_size = local_shape[0]
        part_size = local_shape[gather_dim]
        ctx.part_size = part_size

        # all_gather_tensor stacks the per-rank tensors along dim 0; split them apart again
        # and re-join them along the requested dimension.
        output = all_gather_tensor(local_tensor, group, async_op)
        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Any:
        """
        Return this rank's slice of ``grad_output`` along ``gather_dim``.

        Returns:
            One entry per ``forward`` input (group, local_tensor, gather_dim, grad_scaler,
            async_op); only ``local_tensor`` receives a gradient.
        """
        if ctx.grad_scaler:
            grad_output = grad_output * ctx.sp_world_size
        local_grad = grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous()
        # Exactly five gradients: one per forward argument after ctx.
        return None, local_grad, None, None, None
|
|
|
|
|
|
|
|
|
def gather_outpus_and_unpad(
    x: Tensor,
    gather_dim: int,
    unpad_dim: int = None,
    padding_size: int = 0,
    grad_scaler: bool = True,
    group: Optional[dist.ProcessGroup] = None,
):
    """
    Gather ``x`` from all ranks along ``gather_dim`` and optionally strip trailing padding.

    Args:
        x: local tensor.
        gather_dim: dimension along which the per-rank tensors are concatenated.
        unpad_dim: dimension to trim after gathering; ``None`` skips trimming.
        padding_size: number of trailing entries removed along ``unpad_dim``.
        grad_scaler: forwarded to ``Gather``.
        group: process group; defaults to the module-wide sequence parallel group.

    Returns:
        ``x`` itself when no group is available, otherwise the gathered (and possibly
        trimmed) tensor.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if group is None:
        return x

    gathered = Gather.apply(group, x, gather_dim, grad_scaler)
    if unpad_dim is None:
        return gathered

    assert isinstance(padding_size, int), "padding size is not given or is not an integer"
    if padding_size == 0:
        # Nothing was padded, so there is nothing to strip.
        return gathered
    return _unpad_tensor(gathered, unpad_dim, padding_size)
|
|
|
|
|
|
|
|
|
def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1):
    """
    Pad input_ids (and position_ids, when given) so the sequence length is divisible by
    sp_size, then slice out the current rank's shard.

    This is the pre-forward utility for ulysses sequence parallelism.

    Args:
        input_ids_rmpad: shape of [bsz, seqlen]
        position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
        sp_size (int): ulysses sequence parallelism size

    Returns:
        torch.Tensor: padded and sliced input_ids
        torch.Tensor: padded and sliced position_ids (``None`` when none were given)
        int: pad size

    When ``sp_size <= 1`` the inputs are returned unchanged with a pad size of 0.
    """
    has_position_ids = position_ids_rmpad is not None
    if has_position_ids:
        assert position_ids_rmpad.size(0) == 1
        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
    if sp_size <= 1:
        return input_ids_rmpad, position_ids_rmpad, 0

    _, total_seq_len = input_ids_rmpad.shape
    # Distance to the next multiple of sp_size (0 when already aligned).
    pad_size = (-total_seq_len) % sp_size
    if pad_size > 0:
        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
        if has_position_ids:
            # The padded positions are numbered 0 .. pad_size - 1.
            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)

    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
    if has_position_ids:
        position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)
    return input_ids_rmpad, position_ids_rmpad, pad_size
|
|
|
|
|
|
|
|
|
def validate_ulysses_config(num_heads, ulysses_sequence_size):
    """
    Check that the attention heads can be split evenly across the sequence parallel ranks.

    Args:
        num_heads: number of attention heads.
        ulysses_sequence_size: ulysses sequence parallel size; values not greater than 1
            skip the check.

    Raises:
        AssertionError: if ``num_heads`` is not divisible by ``ulysses_sequence_size``.
    """
    if not ulysses_sequence_size > 1:
        return
    leftover = num_heads % ulysses_sequence_size
    assert leftover == 0, f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})"
|
|
|
|