|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Utilities for DeepSpeed Ulysses Sequence Parallelism. |
|
|
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509 |
|
|
Inspired by: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
|
|
""" |
|
|
|
|
|
from typing import Any, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch import Tensor |
|
|
from torch.distributed import ProcessGroup |
|
|
|
|
|
|
|
|
# Module-level Ulysses sequence-parallel process group. Written by
# set_ulysses_sequence_parallel_group() and read by get_ulysses_sequence_parallel_group();
# stays None until a group is set.
_ULYSSES_SEQUENCE_PARALLEL_GROUP: Optional[ProcessGroup] = None
|
|
|
|
|
|
|
|
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the module-wide Ulysses sequence-parallel process group.

    The value is stored in a module-level global and is what
    ``get_ulysses_sequence_parallel_group`` subsequently returns.
    """
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """Return the registered Ulysses sequence-parallel process group.

    This is whatever was last passed to ``set_ulysses_sequence_parallel_group``
    (None until a group has been set).
    """
    # Reading a module global does not require a ``global`` declaration.
    return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
    """Return the number of ranks in the Ulysses sequence-parallel group.

    Args:
        group: Process group to query. When None, the group registered via
            ``set_ulysses_sequence_parallel_group`` is used.

    Returns:
        ``dist.get_world_size(group)`` for a resolved group; 1 when no group is
        available (sequence parallelism disabled).
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return 1
    return dist.get_world_size(group)
|
|
|
|
|
|
|
|
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
    """Return this process's rank within the Ulysses sequence-parallel group.

    Args:
        group: Process group to query. When None, the group registered via
            ``set_ulysses_sequence_parallel_group`` is used.

    Returns:
        ``dist.get_rank(group)`` for a resolved group; 0 when no group is
        available (sequence parallelism disabled).
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return 0
    return dist.get_rank(group)
|
|
|
|
|
|
|
|
def gather_seq_scatter_heads(
    x: Tensor,
    seq_dim: int,
    head_dim: int,
    unpadded_dim_size: int = 0,
    group: ProcessGroup = None,
) -> Tensor:
    """All-to-all that gathers the sequence dim and scatters the head dim.

    Used to sync embedding inputs across sequence-parallel ranks. For example,
    with ``seq_dim=1`` and ``head_dim=2``::

        [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]

    Args:
        x: Local input shard.
        seq_dim: Dimension that is gathered across ranks.
        head_dim: Dimension that is scattered across ranks.
        unpadded_dim_size: Length of the sequence dim before padding. When non-zero
            and not divisible by the world size, the trailing padding is removed from
            the gathered sequence dim.
        group: Process group; defaults to the registered Ulysses group.

    Returns:
        The exchanged tensor, or ``x`` unchanged when no group is available.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return x

    world = get_ulysses_sequence_parallel_world_size(group)
    gathered = SeqAllToAll.apply(group, x, head_dim, seq_dim)

    needs_unpad = bool(unpadded_dim_size) and unpadded_dim_size % world != 0
    if not needs_unpad:
        return gathered
    # Everything beyond the original length along seq_dim is padding.
    return _unpad_tensor(gathered, seq_dim, gathered.size(seq_dim) - unpadded_dim_size)
|
|
|
|
|
|
|
|
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
    """All-to-all that gathers the head dim and scatters the sequence dim.

    Used to sync attention results across sequence-parallel ranks. For example,
    with ``seq_dim=1`` and ``head_dim=2``::

        [bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]

    If the sequence length is not divisible by the world size, the sequence dim is
    zero-padded up to the next multiple before the exchange.

    Returns:
        The exchanged tensor, or ``x`` unchanged when no group is available.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return x

    seq_len = x.size(seq_dim)
    world = get_ulysses_sequence_parallel_world_size(group)
    remainder = seq_len % world
    if remainder:
        # Pad so the sequence dim splits evenly across ranks.
        x = _pad_tensor(x, seq_dim, world - remainder)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
|
|
|
|
|
|
|
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: |
|
|
shape = list(x.shape) |
|
|
shape[dim] = padding_size |
|
|
pad = torch.zeros(shape, dtype=x.dtype, device=x.device) |
|
|
return torch.cat([x, pad], dim=dim) |
|
|
|
|
|
|
|
|
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: |
|
|
slc = [slice(None)] * len(x.shape) |
|
|
slc[dim] = slice(0, -padding_size) |
|
|
return x[slc] |
|
|
|
|
|
|
|
|
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
    """Return this rank's contiguous shard of ``x`` along ``dim``.

    ``x`` is cut into ``world_size`` equal parts along ``dim``, and the part at index
    ``rank`` is returned.

    Args:
        x: Full (unsharded) tensor.
        dim: Dimension to shard.
        padding: If True and ``x.size(dim)`` is not divisible by the world size,
            zero-pad ``dim`` up to the next multiple first. If False, the part size is
            floor-divided and trailing elements beyond ``world_size * part`` are dropped.
        group: Process group to shard over; defaults to the registered Ulysses group.

    Returns:
        A contiguous tensor holding this rank's shard.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group)
    # Take the rank from the SAME group as the world size; previously the rank always came
    # from the registered group, giving a wrong shard index for an explicitly passed group.
    sp_rank = get_ulysses_sequence_parallel_rank(group)

    dim_size = x.size(dim)
    if padding and dim_size % sp_world_size:
        x = _pad_tensor(x, dim, sp_world_size - dim_size % sp_world_size)

    parts = x.size(dim) // sp_world_size
    return x.narrow(dim, sp_rank * parts, parts).contiguous()
|
|
|
|
|
|
|
|
def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
    async_op: bool = False,
):
    """All-to-all exchange over ``group``.

    ``local_input`` is split into ``world_size`` chunks along ``scatter_dim``; chunk ``i``
    is sent to rank ``i``, and the chunks received from all ranks are concatenated
    along ``gather_dim``.

    Every receive buffer is shaped like the first send chunk, so ``scatter_dim``
    presumably needs to be divisible by the world size (callers in this module pad
    first) — TODO confirm for other callers.

    Returns:
        The concatenated tensor when ``async_op`` is False. Otherwise, a zero-argument
        callable that waits for the collective and then returns that tensor.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    world_size = dist.get_world_size(group)

    send_chunks = [chunk.contiguous() for chunk in torch.tensor_split(local_input, world_size, scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in send_chunks]
    handle = dist.all_to_all(recv_chunks, send_chunks, group=group, async_op=async_op)

    def _assemble() -> Tensor:
        return torch.cat(recv_chunks, dim=gather_dim).contiguous()

    if not async_op:
        return _assemble()

    def wait():
        handle.wait()
        return _assemble()

    return wait
|
|
|
|
|
|
|
|
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
    """All-gather ``local_tensor`` from every rank of ``group``, stacked along dim 0.

    Args:
        local_tensor: This rank's tensor; all ranks must pass the same shape.
        group: Process group; defaults to the registered Ulysses group.
        async_op: Forwarded to ``dist.all_gather_into_tensor``. Because only the output
            tensor (not the work handle) is returned, the handle is waited on before
            returning, so callers never observe a partially written buffer.

    Returns:
        Tensor of shape ``[world_size * local_tensor.size(0), *local_tensor.shape[1:]]``.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group=group)
    output_shape = list(local_tensor.shape)
    output_shape[0] = output_shape[0] * sp_world_size
    output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)
    # Collectives expect contiguous inputs; .contiguous() is a no-op when already contiguous.
    work = dist.all_gather_into_tensor(output, local_tensor.contiguous(), group=group, async_op=async_op)
    if async_op:
        # The handle was previously dropped, so `output` could be read before the
        # collective finished writing it.
        work.wait()
    return output
|
|
|
|
|
|
|
|
class SeqAllToAll(torch.autograd.Function):
    """Autograd wrapper around ``all_to_all_tensor``.

    Forward scatters ``scatter_dim`` and gathers ``gather_dim``. Backward applies the
    inverse all-to-all to the gradient (scatter ``gather_dim``, gather ``scatter_dim``).
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool = False,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            # NOTE(review): with async_op=True, all_to_all_tensor returns a wait() callable
            # rather than a tensor, yet this branch concatenates grad_output[1:] — confirm it
            # matches the intended async output structure.
            input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
        else:
            input_t = grad_output[0]
        # Inverse exchange: swap scatter/gather dims; always synchronous in backward.
        grad_input = all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False)
        # Exactly one gradient per forward input: (group, local_input, scatter_dim,
        # gather_dim, async_op). Only local_input is differentiable.
        return None, grad_input, None, None, None
|
|
|
|
|
|
|
|
class Gather(torch.autograd.Function):
    """Autograd all-gather that concatenates every rank's tensor along ``gather_dim``.

    Backward returns this rank's slice of the incoming gradient along ``gather_dim``,
    optionally multiplied by the world size when ``grad_scaler`` is True (presumably to
    compensate for gradient averaging across sequence-parallel ranks — TODO confirm).
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_tensor: Tensor,
        gather_dim: int,
        grad_scaler: bool = True,
        async_op: bool = False,
    ) -> Tensor:
        ctx.group = group
        ctx.gather_dim = gather_dim
        ctx.grad_scaler = grad_scaler
        ctx.async_op = async_op

        sp_world_size = dist.get_world_size(group=group)
        ctx.sp_world_size = sp_world_size

        sp_rank = dist.get_rank(group=group)
        ctx.sp_rank = sp_rank

        local_shape = list(local_tensor.size())
        split_size = local_shape[0]
        # Every rank contributes a block of this size along gather_dim; backward uses it
        # to cut this rank's slice back out of the gradient.
        part_size = local_shape[gather_dim]
        ctx.part_size = part_size

        # all_gather_tensor stacks per-rank tensors along dim 0; split them back apart
        # and concatenate along gather_dim instead.
        output = all_gather_tensor(local_tensor, group, async_op)
        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.grad_scaler:
            grad_output = grad_output * ctx.sp_world_size
        grad_input = grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous()
        # Exactly one gradient per forward input: (group, local_tensor, gather_dim,
        # grad_scaler, async_op). Only local_tensor is differentiable.
        return None, grad_input, None, None, None
|
|
|
|
|
|
|
|
def gather_outputs_and_unpad(
    x: Tensor,
    gather_dim: int,
    unpad_dim: int = None,
    padding_size: int = 0,
    grad_scaler: bool = True,
    group: Optional[dist.ProcessGroup] = None,
):
    """All-gather ``x`` along ``gather_dim`` across the SP group, then optionally unpad.

    Args:
        x: Local shard.
        gather_dim: Dimension along which the shards are concatenated.
        unpad_dim: If given, remove ``padding_size`` trailing entries along this dim
            after gathering.
        padding_size: Number of trailing padding entries; must be an int when
            ``unpad_dim`` is given.
        grad_scaler: Forwarded to ``Gather``; scales the gradient by the world size.
        group: Process group; defaults to the registered Ulysses group.

    Returns:
        The gathered (and possibly unpadded) tensor, or ``x`` unchanged when no group
        is available.
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if group is None:
        return x

    gathered = Gather.apply(group, x, gather_dim, grad_scaler)
    if unpad_dim is None:
        return gathered

    assert isinstance(padding_size, int), "padding size is not given or is not an integer"
    if padding_size == 0:
        return gathered
    return _unpad_tensor(gathered, unpad_dim, padding_size)
|
|
|
|
|
|
|
|
def ulysses_pad_and_slice_inputs(
    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
):
    """Pre-forward utility for Ulysses sequence parallelism.

    Right-pads the sequence dim of ``input_ids_rmpad`` (with 0) and of
    ``position_ids_rmpad`` (with ``0, 1, ..., pad_size - 1``) so its length is divisible
    by ``sp_size``. It then slices ``input_ids_rmpad`` to this rank's shard.
    ``position_ids_rmpad`` is padded but NOT sliced.

    Args:
        input_ids_rmpad: Tensor of shape [bsz, seqlen].
        position_ids_rmpad: Tensor of shape [bsz, seqlen] with bsz == 1, or None.
        sp_size (int): Ulysses sequence-parallel size; values <= 1 return the inputs
            unchanged with a pad size of 0.

    Returns:
        torch.Tensor: padded and sliced input_ids.
        Optional[torch.Tensor]: padded (not sliced) position_ids, or None if not given.
        int: number of padding positions appended along the sequence dim.
    """
    if position_ids_rmpad is not None:
        assert position_ids_rmpad.size(0) == 1
        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
    if sp_size <= 1:
        return input_ids_rmpad, position_ids_rmpad, 0
    _, total_seq_len = input_ids_rmpad.shape
    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
    if pad_size > 0:
        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
        if position_ids_rmpad is not None:
            # torch.arange defaults to int64; match the existing dtype so torch.cat does not
            # mix dtypes (which would change the dtype of the returned position_ids).
            pad_pos_ids = torch.arange(
                pad_size, dtype=position_ids_rmpad.dtype, device=position_ids_rmpad.device
            ).unsqueeze(0)
            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
    return input_ids_rmpad, position_ids_rmpad, pad_size
|
|
|