from typing import Optional

import torch


def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Causal depthwise conv1d forward (prefill) pass.

    The underlying op returns nothing: the activation output is written
    back into ``x`` in place, and ``conv_states`` (when given) is updated
    with the trailing context of each sequence for later decode steps.
    """
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )


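# A minimal usage sketch for causal_conv1d_fwd, illustration only (the
# `_demo_*` helpers are not part of this module's API).  The layouts are
# assumptions borrowed from the Mamba-style causal-conv1d convention:
# x as (batch, dim, seqlen), weight as (dim, width); pad_slot_id=-1 is an
# assumed sentinel.  Requires a CUDA build of sgl_kernel.
def _demo_causal_conv1d_fwd():
    batch, dim, seqlen, width = 2, 64, 16, 4
    x = torch.randn(batch, dim, seqlen, device="cuda")
    weight = torch.randn(dim, width, device="cuda")
    bias = torch.randn(dim, device="cuda")
    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states=None,  # no cross-request conv cache in this sketch
        query_start_loc=None,  # dense batch, no varlen packing
        cache_indices=None,
        has_initial_state=None,
        silu_activation=True,
        pad_slot_id=-1,
    )
    # The op returns nothing; x now holds the activated conv output.
    return x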


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Causal conv1d decode-step update.

    Consumes the new token(s) in ``x`` together with the rolling window in
    ``conv_state``; both are updated in place (the op returns nothing).
    """
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )


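# A single-token decode sketch for causal_conv1d_update, illustration only.
# The conv_state layout (batch, dim, width - 1) is an assumption: it is the
# rolling window of past inputs that the kernel shifts in place.  Requires
# a CUDA build of sgl_kernel.
def _demo_causal_conv1d_update():
    batch, dim, width = 2, 64, 4
    x = torch.randn(batch, dim, device="cuda")  # one new token per sequence
    conv_state = torch.zeros(batch, dim, width - 1, device="cuda")
    weight = torch.randn(dim, width, device="cuda")
    causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_=None,
        silu_activation=True,
        cache_seqlens=None,
        conv_state_indices=None,  # state row i serves batch row i
        pad_slot_id=-1,
    )
    # Both x and conv_state were updated in place.
    return x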


def causal_conv1d_fn_cpu(
    mixed_qkv_transposed,
    conv_weights,
    bias,
    activation,
    conv_states,
    has_initial_state,
    cache_indices,
    query_start_loc,
    seq_lens_cpu,
):
    """CPU fallback for the causal conv1d forward (prefill) pass.

    Translates the string ``activation`` into the kernel's boolean SiLU
    flag.  ``seq_lens_cpu`` is accepted for interface parity with the CUDA
    caller but is not forwarded to the op.
    """
    return torch.ops.sgl_kernel.causal_conv1d_fwd_cpu(
        mixed_qkv_transposed,
        conv_weights,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation == "silu",
        -1,  # pad_slot_id, matching the CUDA wrapper's parameter
        True,
    )


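# A varlen prefill sketch for the CPU path, illustration only.  Everything
# about the layout is an assumption mirroring the CUDA varlen convention:
# mixed_qkv_transposed as (dim, total_tokens), query_start_loc holding
# cumulative offsets, conv_states as (num_slots, dim, width - 1) rows
# selected by cache_indices.
def _demo_causal_conv1d_fn_cpu():
    dim, width = 64, 4
    seq_lens = torch.tensor([5, 3], dtype=torch.int32)  # two packed sequences
    x = torch.randn(dim, int(seq_lens.sum()))
    weight = torch.randn(dim, width)
    return causal_conv1d_fn_cpu(
        x,
        weight,
        bias=None,
        activation="silu",
        conv_states=torch.zeros(2, dim, width - 1),
        has_initial_state=torch.zeros(2, dtype=torch.bool),
        cache_indices=torch.tensor([0, 1], dtype=torch.int32),
        query_start_loc=torch.tensor([0, 5, 8], dtype=torch.int32),
        seq_lens_cpu=seq_lens,  # accepted for parity; not forwarded
    )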


def causal_conv1d_update_cpu(
    mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices
):
    """CPU fallback for the causal conv1d decode-step update.

    Translates the string ``activation`` into the kernel's boolean SiLU
    flag; ``cache_seqlens`` is not used on this path, so ``None`` is passed
    in its slot.
    """
    return torch.ops.sgl_kernel.causal_conv1d_update_cpu(
        mixed_qkv,
        conv_states,
        conv_weights,
        bias,
        activation == "silu",
        None,  # cache_seqlens
        conv_state_indices,
        -1,  # pad_slot_id
        True,
    )


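# A decode-step sketch for the CPU path, illustration only; shapes are the
# same assumptions as in the CUDA update sketch above, with
# conv_state_indices mapping each batch row onto a cache slot.
def _demo_causal_conv1d_update_cpu():
    batch, dim, width = 2, 64, 4
    mixed_qkv = torch.randn(batch, dim)
    conv_states = torch.zeros(4, dim, width - 1)  # more slots than rows
    conv_weights = torch.randn(dim, width)
    return causal_conv1d_update_cpu(
        mixed_qkv,
        conv_states,
        conv_weights,
        bias=None,
        activation="silu",
        conv_state_indices=torch.tensor([2, 0], dtype=torch.int32),
    )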


def chunk_gated_delta_rule_cpu(
    q,
    k,
    v,
    g,
    beta,
    initial_state,
    cu_seqlens,
    head_first,
    use_qk_l2norm_in_kernel,
):
    """CPU implementation of the chunked gated delta rule.

    Returns ``(core_attn_out, last_recurrent_state, h)``; the CPU op does
    not expose intermediate chunk states, so ``h`` is always ``None`` and
    only keeps the caller-facing triple uniform.
    """
    core_attn_out, last_recurrent_state = (
        torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
            q,
            k,
            v,
            g,
            beta,
            initial_state,
            True,  # assumed to request the final state, which is returned below
            cu_seqlens,
            head_first,
            use_qk_l2norm_in_kernel,
        )
    )
    return core_attn_out, last_recurrent_state, None


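# A usage sketch for the CPU gated delta rule, illustration only.  Shapes
# are assumptions borrowed from FLA's chunk_gated_delta_rule convention
# with head_first=False: q/k as (batch, seqlen, heads, d_k), v as
# (batch, seqlen, heads, d_v), g (log-space decay) and beta (write
# strength) as (batch, seqlen, heads).
def _demo_chunk_gated_delta_rule_cpu():
    batch, seqlen, heads, d_k, d_v = 1, 32, 4, 64, 64
    q = torch.randn(batch, seqlen, heads, d_k)
    k = torch.randn(batch, seqlen, heads, d_k)
    v = torch.randn(batch, seqlen, heads, d_v)
    g = torch.randn(batch, seqlen, heads).sigmoid().log()  # log of a (0, 1) decay
    beta = torch.rand(batch, seqlen, heads)
    out, final_state, _ = chunk_gated_delta_rule_cpu(
        q,
        k,
        v,
        g,
        beta,
        initial_state=None,
        cu_seqlens=None,
        head_first=False,
        use_qk_l2norm_in_kernel=True,
    )
    return out, final_state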