from typing import Optional
import torch


# Mamba: causal conv1d and gated delta rule wrappers over torch.ops.sgl_kernel.
def causal_conv1d_fwd(
x: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool,
pad_slot_id: int,
):
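    """Varlen causal conv1d forward via the sgl_kernel op.

    The wrapper returns nothing: the underlying kernel is expected to write
    its output into ``x`` in place and, when ``conv_states`` is provided,
    to update the per-sequence convolution cache. Cache slots equal to
    ``pad_slot_id`` are treated as padding and skipped.
    """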
torch.ops.sgl_kernel.causal_conv1d_fwd(
x,
weight,
bias_,
conv_states,
query_start_loc,
cache_indices,
has_initial_state,
silu_activation,
pad_slot_id,
)


def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int,
):
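    """Single-token (decode-time) causal conv1d update via the sgl_kernel op.

    The kernel is expected to shift the new activations into ``conv_state``
    and write the convolution output back into ``x`` in place.
    ``conv_state_indices`` selects the cache slot for each sequence; entries
    equal to ``pad_slot_id`` are skipped.
    """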
torch.ops.sgl_kernel.causal_conv1d_update(
x,
conv_state,
weight,
bias_,
silu_activation,
cache_seqlens,
conv_state_indices,
pad_slot_id,
)


def causal_conv1d_fn_cpu(
mixed_qkv_transposed,
conv_weights,
bias,
activation,
conv_states,
has_initial_state,
cache_indices,
query_start_loc,
seq_lens_cpu,
):
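    """CPU counterpart of causal_conv1d_fwd.

    Forwards to the causal_conv1d_fwd_cpu op with pad_slot_id fixed to -1
    (no padded cache slots). ``seq_lens_cpu`` is accepted for call-site
    compatibility but is not used by this wrapper.
    """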
return torch.ops.sgl_kernel.causal_conv1d_fwd_cpu(
mixed_qkv_transposed,
conv_weights,
bias,
conv_states,
query_start_loc,
cache_indices,
has_initial_state,
activation == "silu",
-1,
True,
)


def causal_conv1d_update_cpu(
mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices
):
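    """CPU counterpart of causal_conv1d_update.

    ``cache_seqlens`` is fixed to None and pad_slot_id to -1, so every
    entry of ``conv_state_indices`` is treated as a valid cache slot.
    """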
return torch.ops.sgl_kernel.causal_conv1d_update_cpu(
mixed_qkv,
conv_states,
conv_weights,
bias,
activation == "silu",
None,
conv_state_indices,
-1,
True,
)


def chunk_gated_delta_rule_cpu(
q,
k,
v,
g,
beta,
initial_state,
cu_seqlens,
head_first,
use_qk_l2norm_in_kernel,
):
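    """Chunked gated delta rule recurrence on CPU via the sgl_kernel op.

    Returns ``(core_attn_out, last_recurrent_state, h)``. The final
    recurrent state is always requested (output_final_state=True); ``h``
    is currently always None (see the TODO below).
    """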
core_attn_out, last_recurrent_state = (
torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
q,
k,
v,
g,
beta,
initial_state,
True, # output_final_state
cu_seqlens,
head_first,
use_qk_l2norm_in_kernel,
)
)
    h = None  # TODO: add support for returning the intermediate hidden states h
return core_attn_out, last_recurrent_state, h
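

# Minimal decode-step usage sketch (illustration only, not part of this
# module). It assumes a CUDA build of sgl_kernel and that the op follows
# the vLLM-style causal-conv1d layout, i.e. x: (batch, dim),
# conv_state: (batch, dim, state_len >= width - 1), weight: (dim, width):
#
#     batch, dim, width = 2, 64, 4
#     x = torch.randn(batch, dim, device="cuda")
#     conv_state = torch.zeros(batch, dim, width - 1, device="cuda")
#     weight = torch.randn(dim, width, device="cuda")
#     causal_conv1d_update(
#         x,
#         conv_state,
#         weight,
#         bias_=None,
#         silu_activation=True,
#         cache_seqlens=None,
#         conv_state_indices=None,
#         pad_slot_id=-1,
#     )
#     # x and conv_state are expected to be updated in place.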