|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
from einops import rearrange, repeat |
|
|
|
|
|
from ..ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask |
|
|
from ..utils import tensor_cache |
|
|
|
|
|
|
|
|
class IndexFirstAxis(torch.autograd.Function):
    """Autograd function that selects rows of a tensor along its first axis.

    Forward gathers the rows of ``x`` named by ``indices``; backward writes the
    incoming gradient into the same rows of a zero-initialised tensor whose
    first dimension matches the original input.
    """

    @staticmethod
    def forward(ctx, x, indices):
        """Return ``x`` restricted to the rows listed in ``indices``.

        Args:
            x: Tensor with at least two dimensions; rows are taken along dim 0.
            indices: Row indices, consumed through the einops pattern
                ``"z -> z d"`` (i.e. expected to be one-dimensional).

        Returns:
            Tensor of shape ``[len(indices), *x.shape[1:]]``.
        """
        ctx.save_for_backward(indices)
        assert x.ndim >= 2
        # The size of the first axis is needed to rebuild the gradient shape.
        ctx.first_axis_dim = x.shape[0]
        trailing_shape = x.shape[1:]
        flat_width = trailing_shape.numel()

        # Collapse every trailing dimension so a plain 2-D gather along dim 0
        # can be used, with the row index broadcast across all columns.
        flat_x = rearrange(x, "b ... -> b (...)")
        gather_index = repeat(indices, "z -> z d", d=flat_width)
        selected = torch.gather(flat_x, 0, gather_index)
        return selected.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, do):
        """Route the gradient ``do`` back to the rows selected in ``forward``.

        Returns:
            A pair ``(dx, None)``; ``dx`` has shape
            ``[ctx.first_axis_dim, *do.shape[1:]]`` and ``indices`` receives no
            gradient.
        """
        (indices,) = ctx.saved_tensors
        assert do.ndim >= 2
        trailing_shape = do.shape[1:]
        flat_do = rearrange(do, "b ... -> b (...)")
        flat_width = flat_do.shape[1]

        dx = torch.zeros(
            [ctx.first_axis_dim, flat_width],
            device=flat_do.device,
            dtype=flat_do.dtype,
        )
        # NOTE(review): scatter_ overwrites rather than accumulates, so this
        # assumes `indices` holds no duplicates -- confirm against callers.
        scatter_index = repeat(indices, "z -> z d", d=flat_width)
        dx.scatter_(0, scatter_index, flat_do)
        return dx.reshape(ctx.first_axis_dim, *trailing_shape), None


# Functional entry point for the autograd function above.
index_first_axis = IndexFirstAxis.apply
|
|
|
|
|
|
|
|
class IndexPutFirstAxis(torch.autograd.Function):
    """Autograd function that places rows into a zero-filled, larger tensor.

    Forward writes the rows of ``x`` at positions ``indices`` of a new tensor
    with ``first_axis_dim`` rows; backward reads the gradient back from those
    same positions.
    """

    @staticmethod
    def forward(ctx, x, indices, first_axis_dim):
        """Build a ``[first_axis_dim, *x.shape[1:]]`` tensor holding ``x`` at ``indices``.

        Args:
            x: Tensor with at least two dimensions providing the rows to place.
            indices: One-dimensional tensor of destination row positions.
            first_axis_dim: Number of rows of the output tensor.

        Returns:
            Tensor with the dtype and device of ``x``; rows not named in
            ``indices`` remain zero.
        """
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert x.ndim >= 2

        out_shape = (first_axis_dim, *x.shape[1:])
        padded = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
        padded[indices] = x
        return padded

    @staticmethod
    def backward(ctx, do):
        """Pick the gradient rows that correspond to the placed inputs.

        Returns:
            A triple ``(dx, None, None)``; only ``x`` receives a gradient.
        """
        (indices,) = ctx.saved_tensors
        return do[indices], None, None


# Functional entry point for the autograd function above.
index_put_first_axis = IndexPutFirstAxis.apply
|
|
|
|
|
|
|
|
@tensor_cache
def get_unpad_data(
    attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Derive the index bookkeeping needed to move between padded and unpadded
    (ragged) layouts from an attention mask.

    The function is wrapped in `tensor_cache` (defined elsewhere in the
    project; presumably it memoises results per input tensor -- confirm there).

    Args:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            Positions of the non-zero mask entries within the flattened mask.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
            `cu_seqlens` shape is [batch_size + 1].
        max_seqlen_in_batch (`int`):
            Largest value returned by `prepare_lens_from_mask` for this mask,
            i.e. the maximum sequence length in the batch.
    """
    # Lengths first: `.item()` below turns the maximum into a Python number.
    seq_lens = prepare_lens_from_mask(attention_mask)
    max_seqlen_in_batch = seq_lens.max().item()

    # Flat positions of every valid token in the (batch * seq_len) layout.
    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()

    cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
    return indices, cu_seqlens, max_seqlen_in_batch
|
|
|
|
|
|
|
|
def unpad_input(
    q: torch.Tensor,
    states: Tuple[torch.Tensor],
    attention_mask: torch.Tensor,
    q_len: int,
    keepdim: bool = False,
):
    """
    Strip padding from the query and from every tensor in `states`, packing all
    valid tokens of the batch along a single leading dimension.

    Two regimes are supported: `q_len` equal to the sequence length of
    `states[0]` (the query is unpadded with the same indices as the states) and
    `q_len == 1` (the query contributes exactly one token per batch entry).

    Arguments:
        q (`torch.Tensor`):
            Query state with padding. Shape: [batch_size, q_len, ...].
        states (`Tuple[torch.Tensor]`):
            Attention state with padding. Shape: [batch_size, seq_len, ...].
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape [batch_size, sequence_length], 1 means valid and 0 means not valid.
        q_len (`int`):
            Target length.
        keepdim (`bool`):
            Whether to keep the batch dimension. Default: `False`.

    Return:
        q (`torch.Tensor`):
            Query state without padding.
            Shape: [1, total_target_length, ...] if `keepdim=True` else [total_target_length, ...].
        states (`Tuple[torch.Tensor]`):
            Attention state without padding.
            Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...].
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value),
            used to index into ragged (unpadded) tensors.
            `cu_seqlens` shape is [batch_size + 1].
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence
            i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).

    Raises:
        NotImplementedError: if `q_len` is neither the state sequence length nor 1.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
    batch_size, seq_len, *_ = states[0].shape

    def _gather_valid_tokens(padded: torch.Tensor) -> torch.Tensor:
        # Merge batch and sequence dims, then keep only the unmasked positions.
        merged = rearrange(padded, "b s ... -> (b s) ...")
        return index_first_axis(merged, indices_k)

    unpadded_states = tuple(_gather_valid_tokens(s) for s in states)

    if q_len == seq_len:
        # Query shares the layout of the states, so it reuses their bookkeeping.
        q = _gather_valid_tokens(q)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif q_len == 1:
        # One query token per batch entry: boundaries are simply 0..batch_size.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
        indices_q = cu_seqlens_q[:-1]
        q = q.squeeze(1)
    else:
        raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

    if keepdim:
        # Re-introduce a leading dimension of size 1 on every unpadded tensor.
        q = q.unsqueeze(0)
        unpadded_states = tuple(s.unsqueeze(0) for s in unpadded_states)

    seqlen_boundaries = (cu_seqlens_q, cu_seqlens_k)
    max_seqlens = (max_seqlen_in_batch_q, max_seqlen_in_batch_k)
    return q, unpadded_states, indices_q, seqlen_boundaries, max_seqlens
|
|
|
|
|
|
|
|
def pad_input(
    hidden_states: torch.Tensor,
    indices: torch.LongTensor,
    batch_size: int,
    seq_len: int,
) -> torch.Tensor:
    """
    Scatter unpadded tokens back into a padded [batch_size, seq_len, ...] tensor.

    Args:
        hidden_states ([total_tokens, ...]):
            where total_tokens denotes the number of tokens selected in attention_mask.
        indices ([total_tokens]):
            the indices that represent the non-masked tokens of the original padded input sequence.
        batch_size (int):
            batch size for the padded sequence.
        seq_len (int):
            maximum sequence length for the padded sequence.

    Return:
        hidden_states of shape [batch_size, seq_len, ...]
    """
    # Place the tokens into a flat (batch_size * seq_len) buffer first ...
    total_positions = batch_size * seq_len
    flat_padded = index_put_first_axis(hidden_states, indices, total_positions)
    # ... then split the leading dimension back into batch and sequence axes.
    return rearrange(flat_padded, "(b s) ... -> b s ...", b=batch_size)
|
|
|