# Source path: org_gdn_1B/fla2/ops/simple_gla/recurrent_fuse.py
# Uploaded by msj19 ("Add files using upload-large-folder tool"), commit b68ddd6 (verified).
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple, Optional
import torch
from fla.ops.common.fused_recurrent import fused_recurrent
def fused_recurrent_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the shared fused recurrent kernel in its "simple GLA" configuration.

    Thin wrapper around ``fla.ops.common.fused_recurrent.fused_recurrent``.
    It fills in a default ``scale`` and forwards every other argument
    unchanged. Only ``g`` is supplied as a gate; the two gate slots after it
    are passed as ``None``.

    Args:
        q: Query tensor. Only ``q.shape[-1]`` is read here, to derive the
            default ``scale``. NOTE(review): the expected layout (e.g. batch,
            heads, time, dim) is defined by ``fused_recurrent`` -- confirm there.
        k: Key tensor, forwarded unchanged.
        v: Value tensor, forwarded unchanged.
        g: Gate tensor, forwarded as the fourth positional argument of
            ``fused_recurrent``. NOTE(review): presumably one scalar log-gate
            per head and time step -- confirm the shape against the kernel.
        scale: Multiplier forwarded to the kernel. When ``None``, defaults to
            ``q.shape[-1] ** -0.5``.
        initial_state: Optional state to start the recurrence from. Forwarded
            as-is, including ``None``.
        output_final_state: Forwarded to the kernel. Presumably it controls
            whether ``final_state`` is produced or returned as ``None`` --
            TODO confirm.
        reverse: Forwarded to the kernel. Presumably it runs the recurrence
            back-to-front -- TODO confirm.

    Returns:
        ``(o, final_state)`` exactly as returned by ``fused_recurrent``.
        ``final_state`` is typed ``Optional`` because this wrapper does not
        guarantee it is a tensor.
    """
    if scale is None:
        # Default scaling is 1 / sqrt(q.shape[-1]).
        scale = q.shape[-1] ** -0.5
    # The call is positional on purpose, to match the callee's parameter order:
    #   (q, k, v, g, <None>, <None>, scale, initial_state,
    #    output_final_state, reverse)
    # The two ``None``s fill the gate slots after ``g`` (presumably per-key /
    # per-value gates, which the simple-gate variant does not use). Keep this
    # order in sync if the callee's signature changes.
    o, final_state = fused_recurrent(
        q, k, v, g, None, None,
        scale, initial_state, output_final_state, reverse,
    )
    return o, final_state