| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| |
|
| |
|
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Remove padding tokens from a batched, padded sequence tensor.

    Arguments:
        hidden_states: (batch, seqlen, ...) padded input.
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in
            attention_mask + unused_mask.
        indices: (total_nnz), the indices of the selected tokens in the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int, length of the longest (selected) sequence in the batch.
        seqused: (batch), the number of tokens selected by attention_mask alone.
    """
    # Tokens to keep: attended tokens plus any allocated-but-unused slots.
    if unused_mask is None:
        combined_mask = attention_mask
    else:
        combined_mask = attention_mask + unused_mask

    # Per-sequence counts: kept tokens vs. actually-attended tokens.
    kept_per_seq = combined_mask.sum(dim=-1, dtype=torch.int32)
    used_per_seq = attention_mask.sum(dim=-1, dtype=torch.int32)

    # Positions of kept tokens within the flattened (batch * seqlen) layout.
    flat_indices = torch.nonzero(combined_mask.flatten(), as_tuple=False).flatten()

    longest = kept_per_seq.max().item()
    # Exclusive prefix sum, prepended with 0: cu_seqlens[i]..cu_seqlens[i+1]
    # delimits sequence i in the packed output.
    cu_seqlens = F.pad(torch.cumsum(kept_per_seq, dim=0, dtype=torch.int32), (1, 0))

    # Merge (batch, seqlen) into one axis, then gather the kept rows.
    packed = hidden_states.flatten(0, 1)[flat_indices]

    return (
        packed,
        flat_indices,
        cu_seqlens,
        longest,
        used_per_seq,
    )
| |
|
| |
|
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter packed (unpadded) tokens back into a zero-padded batch tensor.

    Inverse of `unpad_input` for the hidden states.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected
            in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the
            original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...), zeros at padding positions.
    """
    feature_dims = hidden_states.shape[1:]
    # new_zeros inherits dtype and device from hidden_states.
    padded = hidden_states.new_zeros(batch * seqlen, *feature_dims)
    padded[indices] = hidden_states
    return padded.view(batch, seqlen, *feature_dims)
| |
|