bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ...data.constants import IGNORE_INDEX
from .comm import get_ulysses_sequence_parallel_group, get_unified_sequence_parallel_group
from .ulysses import _Gather, _Slice
from .utils import pad_tensor, unpadding_tensor_for_seqeunce_parallel
def slice_input_tensor(
x: Tensor,
dim: int,
padding: bool = True,
padding_value: int = 0,
group: ProcessGroup = None,
) -> Tensor:
"""
A func to slice the input sequence in sequence parallel
"""
group = get_unified_sequence_parallel_group() if group is None else group
if not group:
return x
sp_rank = dist.get_rank(group)
sp_world = dist.get_world_size(group)
dim_size = x.shape[dim]
unit = (dim_size + sp_world - 1) // sp_world
if padding and dim_size % sp_world:
padding_size = sp_world - (dim_size % sp_world)
x = pad_tensor(x, dim, padding_size, padding_value)
slc = [slice(None)] * len(x.shape)
slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
return x[slc].contiguous()
def slice_input_tensor_scale_grad(
x: Tensor,
dim: int,
group: ProcessGroup = None,
scale_grad=True,
):
"""
A func to gather the outputs for the model result in sequence parallel
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
if not group:
return x
x = _Slice.apply(group, x, dim, scale_grad)
return x
def gather_outputs(
x: Tensor,
gather_dim: int,
padding_dim: Optional[int] = None,
unpad_dim_size: Optional[int] = None,
scale_grad=True,
group: ProcessGroup = None,
):
"""
A func to gather the outputs for the model result in sequence parallel
"""
group = get_unified_sequence_parallel_group() if group is None else group
if not group:
return x
x = _Gather.apply(group, x, gather_dim, scale_grad)
if padding_dim is not None:
x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size, group)
return x
def slice_position_embedding(position_embeddings: tuple, dim: int = 1, sp_group: dist.ProcessGroup = None):
"""
Forward hook for LlamaRotaryEmbedding to apply Ulysses tensor slicing.
Args:
position_embeddings: Input tensors to the forward method
dim: The dimension to slice
sp_group: The sequence parallel group
Returns:
Modified (cos, sin) tuple with slicing applied if ulysses is enabled
"""
if sp_group is not None:
cos, sin = position_embeddings
cos = slice_input_tensor(cos, dim=dim, padding=False, group=sp_group)
sin = slice_input_tensor(sin, dim=dim, padding=False, group=sp_group)
return (cos, sin)
return position_embeddings
def sequence_parallel_preprocess(
input_ids: torch.Tensor,
labels: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
sp_group: Optional[ProcessGroup] = None,
):
"""
Preprocess input_ids and labels for sequence parallel training.
Args:
input_ids: Input token ids
labels: Label token ids
position_ids: Position ids
attention_mask: Attention mask
cu_seqlens: Cumulative sequence lengths
Returns:
Preprocessed input_ids, labels, position_ids, attention_mask, cu_seqlens
"""
if sp_group is not None:
sp_size = dist.get_world_size(sp_group)
padding_size = (sp_size - (input_ids.shape[-1] % sp_size)) % sp_size
# Slice input_ids among sequence parallel group
input_ids = slice_input_tensor(input_ids, dim=-1, padding=True, padding_value=0, group=sp_group)
# Slice labels among sequence parallel group
if labels is not None:
labels = labels[..., 1:].contiguous() # shift labels
labels = F.pad(labels, (0, 1), "constant", IGNORE_INDEX) # pad to the same length as input_ids
labels = slice_input_tensor(labels, dim=-1, padding=True, padding_value=IGNORE_INDEX, group=sp_group)
# Padding position_ids
if position_ids is not None:
position_ids = pad_tensor(position_ids, dim=-1, padding_size=padding_size, padding_value=0)
# Padding attention_mask
if attention_mask is not None:
attn_mask_padding_value = 1 if position_ids is not None else 0
attention_mask = pad_tensor(
attention_mask, dim=-1, padding_size=padding_size, padding_value=attn_mask_padding_value
)
# Padding cu_seqlens
if cu_seqlens is not None:
cu_seqlens_padding_value = cu_seqlens[-1].item() + padding_size
cu_seqlens = pad_tensor(
cu_seqlens, dim=-1, padding_size=padding_size, padding_value=cu_seqlens_padding_value
)
return input_ids, labels, position_ids, attention_mask, cu_seqlens