import torch
import torch.distributed as dist

try:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False

    # Fallback stubs so this module can be imported without flash_attn installed;
    # any actual use of these helpers will raise.
    def index_first_axis(*args, **kwargs):
        raise NotImplementedError("flash_attn not available")

    def pad_input(*args, **kwargs):
        raise NotImplementedError("flash_attn not available")

    def rearrange(*args, **kwargs):
        raise NotImplementedError("flash_attn not available")

    def unpad_input(*args, **kwargs):
        raise NotImplementedError("flash_attn not available")

try:
    from flash_attn.utils.distributed import all_gather
except ImportError:

    def all_gather(*args, **kwargs):
        raise NotImplementedError("flash_attn not available")


# Process group used for ring attention; set once during distributed setup.
RING_ATTN_GROUP = None


def set_ring_attn_group(group):
    """Register the process group to use for ring attention."""
    global RING_ATTN_GROUP
    RING_ATTN_GROUP = group


def get_ring_attn_group():
    """Return the registered ring attention process group (or None)."""
    return RING_ATTN_GROUP
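
# A minimal sketch of how a ring attention group might be created and
# registered (assumes torch.distributed is already initialized and that
# ring_attn_size divides the world size; the variable names are illustrative,
# not part of this module's API). Note dist.new_group must be called by all
# ranks for every subgroup, even ranks that do not belong to it:
#
#   world_size = dist.get_world_size()
#   ring_attn_size = 2  # ranks that jointly attend over one packed sequence
#   for start in range(0, world_size, ring_attn_size):
#       ranks = list(range(start, start + ring_attn_size))
#       group = dist.new_group(ranks=ranks)
#       if dist.get_rank() in ranks:
#           set_ring_attn_group(group)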


def reset_ring_attn_position_ids(start, end, packed_seq_lens):
    """
    Calculate position ids for the slice [start, end) of the packed sequences.
    For example, if packed_seq_lens is [3, 2, 4, 1], start=2, end=8,
    the position ids will be [2, 0, 1, 0, 1, 2].

    Args:
        start: the start position
        end: the end position
        packed_seq_lens: the sequence lengths of packed sequences
    """
    position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device())
    offset = 0
    for seqlen in packed_seq_lens:
        # Intersect this packed sequence's span [offset, offset + seqlen)
        # with the requested window [start, end).
        seq_start = max(offset, start)
        seq_end = min(offset + seqlen, end)
        if seq_start < seq_end:
            # Positions restart from 0 at every packed-sequence boundary.
            position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset)
        offset += seqlen
        if offset >= end:
            break
    return position_ids
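
# Worked example for the docstring above (a sketch, not a real doctest: the
# helper allocates on the current CUDA device, so this needs a GPU, and the
# device shown in the repr depends on the current device):
#
#   >>> reset_ring_attn_position_ids(2, 8, [3, 2, 4, 1])
#   tensor([[2, 0, 1, 0, 1, 2]], device='cuda:0')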


def update_ring_attn_params(cu_seqlens):
    """
    Pass the cumulative sequence lengths (cu_seqlens) of the current forward
    pass to the substituted ring_flash_attn.

    Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding.
    """
    assert RING_ATTN_GROUP is not None

    # Imported lazily so ring_flash_attn is only required when ring attention
    # is actually enabled.
    from ring_flash_attn import update_ring_flash_attn_params

    update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP)
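
# For reference, cu_seqlens follows the usual flash-attn varlen convention:
# for packed sequences of lengths [3, 2, 4, 1] it would be
# tensor([0, 3, 5, 9, 10], dtype=torch.int32) -- one more entry than there
# are sequences, with cu_seqlens[-1] equal to the (padded) total length.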


def get_tensor_in_current_ring_attn_rank(tensors: list[torch.Tensor] | torch.Tensor, ring_attn_group, pad_id):
    """
    Pad each tensor so its length is divisible by the ring attention world
    size, then slice out the shard owned by the current ring attention rank.

    Args:
        tensors: Each tensor shaped (batch, seqlen) or (1, total_seqs)
        ring_attn_group: Ring attention group
        pad_id: Padding id
    Returns:
        The processed tensor(s) and the padding length that was added
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    ring_attn_rank = dist.get_rank(group=ring_attn_group)
    ring_attn_size = dist.get_world_size(group=ring_attn_group)
    seqlen = tensors[0].shape[-1]
    total_seq_len = tensors[0].numel()
    # Right-padding needed to make the length divisible by the ring size.
    ring_attn_pad_len = (ring_attn_size - seqlen % ring_attn_size) % ring_attn_size
    output_tensors = []
    for tensor in tensors:
        if tensor.numel() != total_seq_len:
            raise ValueError(f"tensor.numel() {tensor.numel()} != total_seq_len {total_seq_len}")
        tensor = torch.nn.functional.pad(tensor, (0, ring_attn_pad_len), value=pad_id)
        # Each rank keeps a contiguous, equally sized slice of the padded sequence.
        local_seq_len = tensor.numel() // ring_attn_size
        start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len
        tensor = tensor[:, start:end]
        output_tensors.append(tensor)
    if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
    return output_tensors, ring_attn_pad_len
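
# Sharding sketch for a ring group of size 2 and a packed tensor of length 5
# (illustrative values; running it needs an initialized process group):
#
#   pad = (2 - 5 % 2) % 2 = 1, so [1, 2, 3, 4, 5] -> [1, 2, 3, 4, 5, pad_id]
#   rank 0 keeps columns [0:3] -> [1, 2, 3]
#   rank 1 keeps columns [3:6] -> [4, 5, pad_id]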


def unpad_and_slice_tensor(sequences, attention_mask, ring_attn_group):
    """
    Unpad and slice tensor for distributed training with ring attention.

    This function performs several operations:
    1. Removes padding: unpads sequences from (batch, seqlen) to (1, total_seqs)
    2. Pads the result so its length is divisible by the ring_attn_group size
    3. Slices the sequences for the current ring_attn_rank

    Example:
        >>> # Input sequences shape: (batch=2, seqlen=4)
        >>> sequences = [[1, 2, 3, 0], [4, 5, 0, 0]]  # 0 is padding
        >>> attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
        >>> # After unpad:
        >>> # sequences: [1, 2, 3, 4, 5]  # shape (1, total_seqs=5)
        >>> # If ring_attn_group size is 2, it will pad to length 6
        >>> # Then slice for current rank (e.g., rank 0 gets [1,2,3], rank 1 gets [4,5,0])

    Args:
        sequences: Input sequences tensor of shape (batch, seqlen)
        attention_mask: Attention mask tensor for the sequences
        ring_attn_group: Ring attention group for distributed processing

    Returns:
        tuple: Processed sequences and related tensors for ring attention
    """
    # Shift left by one so rolled_sequences[i] is the next token of sequences[i],
    # i.e. the labels for next-token prediction.
    rolled_sequences = torch.roll(sequences, shifts=-1, dims=1)
    sequences, indices, cu_seqlens, _, _ = unpad_input(sequences.unsqueeze(-1), attention_mask)
    sequences = sequences.transpose(0, 1)
    rolled_sequences = index_first_axis(
        rearrange(rolled_sequences.unsqueeze(-1), "b s ... -> (b s) ..."), indices
    ).transpose(0, 1)
    # Position ids restart at 0 for every sequence in the batch; the clip keeps
    # fully masked prefixes at position 0.
    position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None)
    position_ids = index_first_axis(
        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
    ).transpose(0, 1)
    ring_attn_pad_len = 0
    if ring_attn_group is not None:
        (sequences, position_ids, rolled_sequences), ring_attn_pad_len = get_tensor_in_current_ring_attn_rank(
            [sequences, position_ids, rolled_sequences], ring_attn_group, 0
        )
        # The ring attention padding is accounted to the last packed sequence.
        cu_seqlens[-1] += ring_attn_pad_len
        update_ring_attn_params(cu_seqlens)
    return sequences, position_ids, rolled_sequences, ring_attn_pad_len, indices
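
# Shape sketch for the docstring example with a ring group of size 2
# (illustrative; assumes an initialized distributed environment):
#
#   sequences, position_ids, rolled_sequences: (1, 3) on each rank
#   ring_attn_pad_len: 1
#   indices: flat indices of the non-masked tokens, consumed later by
#            gather_and_pad_tensor to restore the (batch, seqlen) layout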


def gather_and_pad_tensor(tensor, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen):
    """
    Gather and pad tensor data (such as logits, log_probs, etc.).

    Example:
        >>> # Input tensor from each rank (shape: (1, local_seq_len))
        >>> # Rank 0: [1, 2, 3]
        >>> # Rank 1: [4, 5, 0]  # 0 is padding
        >>> # After all_gather:
        >>> # tensor: [1, 2, 3, 4, 5, 0]  # shape (1, total_seqs=6)
        >>> # After removing padding (ring_attn_pad_len=1):
        >>> # tensor: [1, 2, 3, 4, 5]  # shape (1, total_seqs=5)
        >>> # After pad_input with original indices:
        >>> # tensor: [[1, 2, 3, 0], [4, 5, 0, 0]]  # shape (batch=2, seqlen=4)

    Args:
        tensor: Input tensor, can be logits, log_probs, etc.
        ring_attn_group: Ring attention group
        ring_attn_pad_len: Padding length added by unpad_and_slice_tensor
        indices: Indices of non-masked tokens returned by unpad_and_slice_tensor
        batch: Batch size
        seqlen: Sequence length

    Returns:
        Padded tensor of shape (batch, seqlen)
    """
    if ring_attn_group is not None:
        # Reassemble the full packed sequence from each rank's local shard,
        # then drop the ring attention padding.
        tensor = all_gather(tensor.transpose(0, 1), ring_attn_group).transpose(0, 1)
        if ring_attn_pad_len > 0:
            tensor = tensor[:, :-ring_attn_pad_len]
    # Scatter the packed tokens back to the original (batch, seqlen) layout.
    tensor = pad_input(tensor.transpose(0, 1), indices, batch, seqlen).squeeze(-1)
    return tensor
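
# End-to-end round trip, as a sketch (assumes an initialized process group and
# a causal LM called `model`; `model` and `compute_log_probs` are illustrative
# names, not part of this module):
#
#   group = get_ring_attn_group()
#   seqs, pos_ids, labels, pad_len, indices = unpad_and_slice_tensor(
#       sequences, attention_mask, group
#   )
#   logits = model(input_ids=seqs, position_ids=pos_ids).logits
#   log_probs = compute_log_probs(logits, labels)  # hypothetical helper
#   # Back to (batch, seqlen), aligned with the original attention_mask:
#   log_probs = gather_and_pad_tensor(
#       log_probs, group, pad_len, indices, *attention_mask.shape
#   )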