# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.common.fused_recurrent import fused_recurrent


def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
gk (torch.Tensor):
Forget gates of shape `[B, T, H, K]`.
gv (torch.Tensor):
Forget gates of shape `[B, T, H, V]` applied to values.
scale (Optional[int]):
Scale factor for the attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
reverse (Optional[bool]):
If `True`, process the state passing in reverse order. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gla import fused_recurrent_gla
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = torch.randn(B, T, H, K, device='cuda')
>>> v = torch.randn(B, T, H, V, device='cuda')
>>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, device='cuda')
>>> o, ht = fused_recurrent_gla(
q, k, v, g,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_recurrent_gla(
q, k, v, g,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
>>> assert o.allclose(o_var.view(o.shape))
"""
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        # default to softmax-attention-style scaling by 1/sqrt(K)
        scale = k.shape[-1] ** -0.5
    # dispatch to the shared fused recurrent kernel; GLA uses the per-dimension gates
    # `gk`/`gv`, so the scalar gate `g` is left unset
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state
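

# A minimal, unfused reference of the recurrence the kernel above computes. This is a
# documentation sketch only (the helper name `_naive_recurrent_gla_reference` is not part
# of the `fla` API); it assumes `gk`/`gv` are log-space forget gates, handles fixed-length
# inputs only (no `cu_seqlens`, no `reverse`), and is far too slow for actual training.
# Per step, with a per-head state `S` of shape `[K, V]`:
#     S_t = exp(gk_t)[:, None] * exp(gv_t)[None, :] * S_{t-1} + k_t[:, None] * v_t[None, :]
#     o_t = (q_t * scale) @ S_t
def _naive_recurrent_gla_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    # accumulate the state in fp32 for numerical stability
    S = q.new_zeros(B, H, K, V, dtype=torch.float) if initial_state is None else initial_state.float().clone()
    o = torch.empty(B, T, H, V, dtype=q.dtype, device=q.device)
    for t in range(T):
        # decay the state along the key and/or value dimension, then add the new outer product
        if gk is not None:
            S = S * gk[:, t].float().exp().unsqueeze(-1)
        if gv is not None:
            S = S * gv[:, t].float().exp().unsqueeze(-2)
        S = S + k[:, t].float().unsqueeze(-1) * v[:, t].float().unsqueeze(-2)
        # contract the scaled query against the state over the key dimension
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t].float() * scale, S).to(q.dtype)
    return o, S
# For small shapes, `fused_recurrent_gla(q, k, v, gk)` and this reference should agree to
# within the usual floating-point tolerance, e.g. via `torch.testing.assert_close(..., rtol=1e-3, atol=1e-3)`.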