| from typing import Optional |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torch.distributed import ProcessGroup |
|
|
| from ...data.constants import IGNORE_INDEX |
| from .comm import get_ulysses_sequence_parallel_group, get_unified_sequence_parallel_group |
| from .ulysses import _Gather, _Slice |
| from .utils import pad_tensor, unpadding_tensor_for_seqeunce_parallel |
|
|
|
|
def slice_input_tensor(
    x: Tensor,
    dim: int,
    padding: bool = True,
    padding_value: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Tensor:
    """Return this rank's shard of ``x`` along ``dim`` for sequence parallelism.

    The dimension is split into chunks of ``ceil(x.shape[dim] / world_size)``
    elements, and the chunk at this rank's index in ``group`` is returned.

    Args:
        x: Tensor to shard.
        dim: Dimension to shard.
        padding: If True and ``x.shape[dim]`` is not divisible by the group size,
            ``x`` is first padded along ``dim`` with ``padding_value`` so every rank
            receives an equally sized chunk. If False, no padding is added and the
            trailing ranks may receive shorter (possibly empty) chunks.
        padding_value: Fill value used for the padding.
        group: Process group to shard over. Defaults to the unified sequence
            parallel group. If the resolved group is falsy, ``x`` is returned as is.

    Returns:
        A contiguous tensor holding this rank's chunk (or ``x`` itself when there
        is no group).
    """
    if group is None:
        group = get_unified_sequence_parallel_group()
    if not group:
        return x
    sp_rank = dist.get_rank(group)
    sp_world = dist.get_world_size(group)
    dim_size = x.shape[dim]
    # Ceil division: equals the per-rank chunk size once x is padded to a multiple of sp_world.
    unit = (dim_size + sp_world - 1) // sp_world
    if padding and dim_size % sp_world:
        padding_size = sp_world - (dim_size % sp_world)
        x = pad_tensor(x, dim, padding_size, padding_value)
    index = [slice(None)] * x.dim()
    index[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
    # Index with a tuple; a list of slices relies on legacy "sequence as tuple" handling.
    return x[tuple(index)].contiguous()
|
|
|
|
def slice_input_tensor_scale_grad(
    x: Tensor,
    dim: int,
    group: ProcessGroup = None,
    scale_grad=True,
):
    """Slice ``x`` along ``dim`` for sequence parallelism using the autograd op ``_Slice``.

    Args:
        x: Tensor to slice.
        dim: Dimension to slice along.
        group: Process group to slice over. Defaults to the Ulysses sequence
            parallel group. If the resolved group is falsy, ``x`` is returned as is.
        scale_grad: Forwarded unchanged to ``_Slice.apply`` (presumably controls
            gradient scaling in the backward pass; confirm in ``.ulysses``).

    Returns:
        The result of ``_Slice.apply(group, x, dim, scale_grad)``, or ``x`` when
        there is no group.
    """
    sp_group = group if group is not None else get_ulysses_sequence_parallel_group()
    if sp_group:
        return _Slice.apply(sp_group, x, dim, scale_grad)
    return x
|
|
|
|
def gather_outputs(
    x: Tensor,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_dim_size: Optional[int] = None,
    scale_grad=True,
    group: ProcessGroup = None,
):
    """Gather sequence-parallel shards of ``x`` along ``gather_dim`` and optionally unpad.

    Args:
        x: This rank's shard of the model output.
        gather_dim: Dimension along which ``_Gather`` concatenates the shards.
        padding_dim: If given, the gathered tensor is passed to
            ``unpadding_tensor_for_seqeunce_parallel`` along this dimension
            (presumably to strip padding added before slicing; confirm in ``.utils``).
        unpad_dim_size: Forwarded to the unpadding helper together with ``padding_dim``.
        scale_grad: Forwarded unchanged to ``_Gather.apply``.
        group: Process group to gather over. Defaults to the unified sequence
            parallel group. If the resolved group is falsy, ``x`` is returned as is.

    Returns:
        The gathered (and optionally unpadded) tensor, or ``x`` when there is no group.
    """
    sp_group = group
    if sp_group is None:
        sp_group = get_unified_sequence_parallel_group()
    if not sp_group:
        return x
    gathered = _Gather.apply(sp_group, x, gather_dim, scale_grad)
    if padding_dim is None:
        return gathered
    return unpadding_tensor_for_seqeunce_parallel(gathered, padding_dim, unpad_dim_size, sp_group)
|
|
|
|
def slice_position_embedding(position_embeddings: tuple, dim: int = 1, sp_group: dist.ProcessGroup = None):
    """Shard a ``(cos, sin)`` rotary-embedding pair for Ulysses sequence parallelism.

    Intended as a forward hook for ``LlamaRotaryEmbedding``-style modules that
    return ``(cos, sin)``. Each tensor is sliced with ``slice_input_tensor``
    (no padding) so it matches this rank's sequence shard.

    Args:
        position_embeddings: The ``(cos, sin)`` pair produced by the module.
        dim: Dimension to slice along.
        sp_group: Sequence parallel group. When None, the input is returned untouched.

    Returns:
        A new ``(cos, sin)`` tuple of sliced tensors, or ``position_embeddings``
        itself when ``sp_group`` is None.
    """
    if sp_group is None:
        return position_embeddings
    cos, sin = position_embeddings
    return tuple(
        slice_input_tensor(emb, dim=dim, padding=False, group=sp_group) for emb in (cos, sin)
    )
|
|
|
|
def _pad_cu_seqlens(cu_seqlens: torch.Tensor, padding_size: int) -> torch.Tensor:
    """Append one boundary to ``cu_seqlens`` so the trailing pad tokens form a single segment.

    ``cu_seqlens`` stores cumulative segment boundaries, so extending the token
    dimension by ``padding_size`` requires exactly one new entry equal to
    ``cu_seqlens[-1] + padding_size``. Appending ``padding_size`` copies of that
    value would add ``padding_size - 1`` spurious zero-length segments.
    """
    if padding_size <= 0:
        return cu_seqlens
    # Slice (not index) the last element to keep a trailing dim for torch.cat;
    # staying on-device avoids the host sync that .item() would force.
    tail = cu_seqlens[..., -1:] + padding_size
    return torch.cat([cu_seqlens, tail], dim=-1)


def sequence_parallel_preprocess(
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    sp_group: Optional[ProcessGroup] = None,
):
    """
    Preprocess a batch for sequence parallel training.

    When ``sp_group`` is None every input is returned unchanged. Otherwise the
    sequence (last) dimension is padded up to a multiple of the group size, and:

    * ``input_ids`` are padded with 0 and sliced to this rank's shard;
    * ``labels`` are shifted left by one (the last position becomes
      ``IGNORE_INDEX``), padded with ``IGNORE_INDEX`` and sliced;
    * ``position_ids`` are padded with 0 but NOT sliced;
    * ``attention_mask`` is padded (with 1 if ``position_ids`` is given, else 0)
      but NOT sliced;
    * ``cu_seqlens`` gains exactly one boundary covering the pad tokens.

    Args:
        input_ids: Input token ids.
        labels: Label token ids (unshifted).
        position_ids: Position ids.
        attention_mask: Attention mask.
        cu_seqlens: Cumulative sequence lengths (segment boundaries).
        sp_group: Sequence parallel process group; None disables preprocessing.

    Returns:
        Tuple ``(input_ids, labels, position_ids, attention_mask, cu_seqlens)``.
    """
    if sp_group is None:
        return input_ids, labels, position_ids, attention_mask, cu_seqlens

    sp_size = dist.get_world_size(sp_group)
    # Pad tokens needed to make the sequence length divisible by sp_size (0 if it already is).
    padding_size = (sp_size - (input_ids.shape[-1] % sp_size)) % sp_size

    input_ids = slice_input_tensor(input_ids, dim=-1, padding=True, padding_value=0, group=sp_group)

    if labels is not None:
        # Shift before slicing: a shard's last next-token target lives at the start of the next shard.
        labels = labels[..., 1:].contiguous()
        labels = F.pad(labels, (0, 1), "constant", IGNORE_INDEX)
        labels = slice_input_tensor(labels, dim=-1, padding=True, padding_value=IGNORE_INDEX, group=sp_group)

    if position_ids is not None:
        position_ids = pad_tensor(position_ids, dim=-1, padding_size=padding_size, padding_value=0)

    if attention_mask is not None:
        # NOTE(review): pads with 1 when position_ids are supplied, else 0; confirm intent with callers.
        attn_mask_padding_value = 1 if position_ids is not None else 0
        attention_mask = pad_tensor(
            attention_mask, dim=-1, padding_size=padding_size, padding_value=attn_mask_padding_value
        )

    if cu_seqlens is not None:
        cu_seqlens = _pad_cu_seqlens(cu_seqlens, padding_size)

    return input_ids, labels, position_ids, attention_mask, cu_seqlens
|
|