from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ...data.constants import IGNORE_INDEX
from .comm import get_ulysses_sequence_parallel_group, get_unified_sequence_parallel_group
from .ulysses import _Gather, _Slice
from .utils import pad_tensor, unpadding_tensor_for_seqeunce_parallel
def slice_input_tensor(
    x: Tensor,
    dim: int,
    padding: bool = True,
    padding_value: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Tensor:
    """
    Return this rank's chunk of ``x`` along ``dim`` for sequence parallelism.

    The size of ``dim`` is divided (rounding up) by the world size of the
    sequence parallel group, and the chunk whose index equals this rank is
    returned as a contiguous tensor.

    Args:
        x: Tensor to slice.
        dim: Dimension to slice along (negative values are accepted).
        padding: When True and the size of ``dim`` is not a multiple of the
            group world size, ``x`` is first passed through ``pad_tensor`` with
            the number of missing elements. When False, no padding is applied,
            so trailing ranks may receive a shorter (possibly empty) chunk
            because slicing clamps to the tensor bounds.
        padding_value: Fill value forwarded to ``pad_tensor``.
        group: Sequence parallel process group. Defaults to the group returned
            by ``get_unified_sequence_parallel_group()``.

    Returns:
        ``x`` itself when no (truthy) group is available, otherwise the
        contiguous chunk owned by this rank.
    """
    group = get_unified_sequence_parallel_group() if group is None else group
    if not group:
        # Sequence parallelism is disabled: nothing to slice.
        return x

    sp_rank = dist.get_rank(group)
    sp_world = dist.get_world_size(group)
    dim_size = x.shape[dim]
    # Ceil division so that sp_world chunks of `unit` elements cover the dim.
    unit = (dim_size + sp_world - 1) // sp_world

    remainder = dim_size % sp_world
    if padding and remainder:
        # presumably pad_tensor appends `sp_world - remainder` elements along
        # `dim` so every rank gets exactly `unit` elements — confirm in .utils
        x = pad_tensor(x, dim, sp_world - remainder, padding_value)

    index = [slice(None)] * x.dim()
    index[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
    # Index with a tuple: indexing a tensor with a *list* of slices relies on
    # the deprecated "non-tuple sequence" behaviour and may be interpreted as a
    # tensor index in newer PyTorch releases.
    return x[tuple(index)].contiguous()
def slice_input_tensor_scale_grad(
    x: Tensor,
    dim: int,
    group: ProcessGroup = None,
    scale_grad=True,
):
    """
    Slice ``x`` along ``dim`` across the sequence parallel group via ``_Slice``.

    Args:
        x: Tensor to slice.
        dim: Dimension to slice along.
        group: Sequence parallel process group. Defaults to the group returned
            by ``get_ulysses_sequence_parallel_group()``.
        scale_grad: Forwarded unchanged to ``_Slice.apply``; presumably toggles
            gradient scaling in the backward pass (confirm in ``.ulysses``).

    Returns:
        ``x`` itself when no (truthy) group is available, otherwise the result
        of ``_Slice.apply``.
    """
    sp_group = group if group is not None else get_ulysses_sequence_parallel_group()
    if sp_group:
        return _Slice.apply(sp_group, x, dim, scale_grad)
    # No sequence parallel group: pass the tensor through untouched.
    return x
def gather_outputs(
    x: Tensor,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_dim_size: Optional[int] = None,
    scale_grad=True,
    group: ProcessGroup = None,
):
    """
    Gather model outputs across the sequence parallel group.

    Args:
        x: Local tensor held by this rank.
        gather_dim: Dimension passed to ``_Gather.apply`` for the gather.
        padding_dim: When not None, the gathered tensor is additionally passed
            through ``unpadding_tensor_for_seqeunce_parallel`` along this
            dimension.
        unpad_dim_size: Size forwarded to the unpadding helper; only used when
            ``padding_dim`` is given.
        scale_grad: Forwarded unchanged to ``_Gather.apply``.
        group: Sequence parallel process group. Defaults to the group returned
            by ``get_unified_sequence_parallel_group()``.

    Returns:
        ``x`` itself when no (truthy) group is available, otherwise the
        gathered (and optionally unpadded) tensor.
    """
    sp_group = group if group is not None else get_unified_sequence_parallel_group()
    if not sp_group:
        # Sequence parallelism is disabled: nothing to gather.
        return x

    gathered = _Gather.apply(sp_group, x, gather_dim, scale_grad)
    if padding_dim is None:
        return gathered
    # Undo the padding that was added before slicing.
    return unpadding_tensor_for_seqeunce_parallel(gathered, padding_dim, unpad_dim_size, sp_group)
def slice_position_embedding(position_embeddings: tuple, dim: int = 1, sp_group: dist.ProcessGroup = None):
    """
    Slice a ``(cos, sin)`` position-embedding pair for sequence parallelism.

    Args:
        position_embeddings: Pair ``(cos, sin)`` of tensors to slice.
        dim: The dimension to slice along.
        sp_group: The sequence parallel group; ``None`` disables slicing.

    Returns:
        ``position_embeddings`` unchanged when ``sp_group`` is None, otherwise
        a new ``(cos, sin)`` tuple where each tensor has been sliced (without
        padding) by ``slice_input_tensor``.
    """
    if sp_group is None:
        return position_embeddings

    cos, sin = position_embeddings
    # Both tensors are sliced identically; padding is disabled for them.
    return tuple(
        slice_input_tensor(emb, dim=dim, padding=False, group=sp_group) for emb in (cos, sin)
    )
def sequence_parallel_preprocess(
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    sp_group: Optional[ProcessGroup] = None,
):
    """
    Prepare a batch for sequence parallel training.

    When ``sp_group`` is None every argument is returned unchanged. Otherwise:

    * ``input_ids`` is padded with 0 and sliced along the last dimension.
    * ``labels`` is shifted left by one along the last dimension (a trailing
      ``IGNORE_INDEX`` keeps the length), then padded with ``IGNORE_INDEX``
      and sliced along the last dimension.
    * ``position_ids``, ``attention_mask`` and ``cu_seqlens`` are only padded
      (not sliced) along the last dimension.

    Args:
        input_ids: Input token ids.
        labels: Label token ids.
        position_ids: Position ids; padded with 0.
        attention_mask: Attention mask; padded with 1 when ``position_ids`` is
            provided, otherwise with 0.
        cu_seqlens: Cumulative sequence lengths; padded with its last value
            plus the padding size.
        sp_group: Sequence parallel process group; ``None`` disables all
            preprocessing.

    Returns:
        Tuple ``(input_ids, labels, position_ids, attention_mask, cu_seqlens)``.
    """
    if sp_group is None:
        return input_ids, labels, position_ids, attention_mask, cu_seqlens

    world = dist.get_world_size(sp_group)
    # Number of elements needed to make the last dim a multiple of `world`.
    # Must be read before input_ids is replaced by its local slice.
    seq_len = input_ids.shape[-1]
    pad_len = (-seq_len) % world

    # Each rank keeps only its own chunk of the token ids.
    input_ids = slice_input_tensor(input_ids, dim=-1, padding=True, padding_value=0, group=sp_group)

    if labels is not None:
        # Shift before slicing so every rank holds the targets of its tokens.
        shifted = F.pad(labels[..., 1:].contiguous(), (0, 1), "constant", IGNORE_INDEX)
        labels = slice_input_tensor(
            shifted, dim=-1, padding=True, padding_value=IGNORE_INDEX, group=sp_group
        )

    if position_ids is not None:
        position_ids = pad_tensor(position_ids, dim=-1, padding_size=pad_len, padding_value=0)

    if attention_mask is not None:
        mask_fill = 1 if position_ids is not None else 0
        attention_mask = pad_tensor(attention_mask, dim=-1, padding_size=pad_len, padding_value=mask_fill)

    if cu_seqlens is not None:
        # NOTE(review): this pads with `pad_len` entries rather than a single
        # extra boundary — confirm against pad_tensor's semantics and callers.
        tail_value = cu_seqlens[-1].item() + pad_len
        cu_seqlens = pad_tensor(cu_seqlens, dim=-1, padding_size=pad_len, padding_value=tail_value)

    return input_ids, labels, position_ids, attention_mask, cu_seqlens
|