from typing import Optional
import torch
# Mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
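    """Causal conv1d prefill/forward pass.

    Thin wrapper over ``torch.ops.sgl_kernel.causal_conv1d_fwd``. Nothing is
    returned, so the op is assumed to write its output into ``x`` and update
    ``conv_states`` in place; ``silu_activation`` selects a fused SiLU, and
    ``pad_slot_id`` marks cache slots to skip (assumed from the argument names).
    """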
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
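    """Single-token (decode) causal conv1d update.

    Thin wrapper over ``torch.ops.sgl_kernel.causal_conv1d_update``. As with
    the prefill path, no value is returned, so the op is assumed to update
    ``x`` and ``conv_state`` in place.
    """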
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )


def causal_conv1d_fn_cpu(
    mixed_qkv_transposed,
    conv_weights,
    bias,
    activation,
    conv_states,
    has_initial_state,
    cache_indices,
    query_start_loc,
    seq_lens_cpu,
):
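    """CPU fallback for the causal conv1d prefill path.

    Calls ``torch.ops.sgl_kernel.causal_conv1d_fwd_cpu`` and returns its
    result. ``seq_lens_cpu`` is accepted by this wrapper but not forwarded to
    the CPU op.
    """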
    return torch.ops.sgl_kernel.causal_conv1d_fwd_cpu(
        mixed_qkv_transposed,
        conv_weights,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation == "silu",  # silu_activation flag
        -1,  # pad_slot_id (assumed, by analogy with the CUDA wrapper's argument order)
        True,  # extra flag taken by the CPU op only; not present in the CUDA call
    )


def causal_conv1d_update_cpu(
    mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices
):
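    """CPU fallback for the single-token causal conv1d update.

    Calls ``torch.ops.sgl_kernel.causal_conv1d_update_cpu`` and returns its
    result; the trailing positional arguments are assumed to mirror the CUDA
    wrapper above (``cache_seqlens=None`` and ``pad_slot_id=-1``).
    """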
    return torch.ops.sgl_kernel.causal_conv1d_update_cpu(
        mixed_qkv,
        conv_states,
        conv_weights,
        bias,
        activation == "silu",  # silu_activation flag
        None,  # cache_seqlens (assumed, by analogy with the CUDA wrapper)
        conv_state_indices,
        -1,  # pad_slot_id (assumed)
        True,  # extra flag taken by the CPU op only
    )


def chunk_gated_delta_rule_cpu(
    q,
    k,
    v,
    g,
    beta,
    initial_state,
    cu_seqlens,
    head_first,
    use_qk_l2norm_in_kernel,
):
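    """CPU chunked gated delta rule kernel wrapper.

    Calls ``torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu`` with
    ``output_final_state=True`` and returns
    ``(core_attn_out, last_recurrent_state, h)``, where ``h`` is always
    ``None`` until returning it is supported (see the TODO below).
    """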
    core_attn_out, last_recurrent_state = (
        torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
            q,
            k,
            v,
            g,
            beta,
            initial_state,
            True,  # output_final_state
            cu_seqlens,
            head_first,
            use_qk_l2norm_in_kernel,
        )
    )
    h = None  # TODO: add support for returning h
    return core_attn_out, last_recurrent_state, h