| |
|
|
| import torch |
|
|
| import causal_conv1d_cuda |
|
|
|
|
# Namespace prefix used in the qualified names (f"{LIBRARY_NAME}::<op>") of the torch.library custom ops in this file.
| LIBRARY_NAME = "DaoAILab" |
|
|
|
|
@torch.library.custom_op(
    f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp",
    mutates_args={"out", "final_states_out"},
)
def _causal_conv1d_fwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    out: torch.Tensor,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_fwd``.

    Every argument is handed to the compiled extension unchanged and in
    declaration order. The op returns nothing; ``out`` and
    ``final_states_out`` are registered with ``torch.library`` as arguments
    that are mutated in place.
    """
    # Positional pass-through: the order must mirror the signature above.
    causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, initial_states, out, final_states_out, silu_activation
    )
|
|
|
|
@torch.library.custom_op(
    f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp",
    mutates_args={
        "dfinal_states",
        "dx",
        "dweight",
        "dbias",
        "dinitial_states",
    },
)
def _causal_conv1d_bwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor,
    dweight: torch.Tensor,
    dbias: torch.Tensor | None,
    dinitial_states: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_bwd``.

    Every argument is handed to the compiled extension unchanged and in
    declaration order. The op returns nothing; ``dfinal_states``, ``dx``,
    ``dweight``, ``dbias`` and ``dinitial_states`` are registered with
    ``torch.library`` as arguments that are mutated in place.

    ``dinitial_states`` is annotated as optional because the op schema is
    derived from these annotations, and ``causal_conv1d_bwd_function`` passes
    ``None`` for it when the gradient of the initial states is not requested.
    A non-optional ``Tensor`` annotation would make that call fail schema
    validation before the kernel is reached.
    """
    # Positional pass-through: the order must mirror the signature above.
    causal_conv1d_cuda.causal_conv1d_bwd(
        x,
        weight,
        bias,
        dout,
        seq_idx,
        initial_states,
        dfinal_states,
        dx,
        dweight,
        dbias,
        dinitial_states,
        silu_activation,
    )
|
|
|
|
@torch.library.custom_op(
    f"{LIBRARY_NAME}::_causal_conv1d_update_cpp",
    mutates_args={"out", "conv_state"},
)
def _causal_conv1d_update_cpp(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    out: torch.Tensor,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_update``.

    Every argument is handed to the compiled extension unchanged and in
    declaration order. The op returns nothing; ``out`` and ``conv_state``
    are registered with ``torch.library`` as arguments that are mutated in
    place.
    """
    # Positional pass-through: the order must mirror the signature above.
    causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, out, silu_activation, cache_seqlens, conv_state_indices
    )
|
|
|
|
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the forward custom op and return a freshly allocated output.

    Allocates an uninitialised tensor shaped like ``x``, lets
    ``_causal_conv1d_fwd_cpp`` fill it in place, and returns it.
    ``final_states_out`` is forwarded as given; the op is registered as
    mutating it, so a caller-supplied buffer may be written.

    Returns:
        The output tensor, with the same shape and dtype as ``x``.
    """
    result = torch.empty_like(x)
    _causal_conv1d_fwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return result
|
|
|
|
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Allocate gradient buffers, run the backward custom op, and return them.

    Args:
        x: Input tensor; its first two dimensions are read as
            ``(batch_size, dim)``.
        weight: Weight tensor; its last dimension is read as the kernel width.
        bias: Optional bias. A ``dbias`` buffer is only allocated when given.
        dout: Forwarded to the kernel unchanged (presumably the gradient with
            respect to the forward output -- confirm against the extension).
        seq_idx: Forwarded to the kernel unchanged.
        initial_states: Forwarded to the kernel unchanged.
        dfinal_states: Forwarded to the kernel unchanged.
        dx: Optional pre-allocated buffer for the input gradient. When
            ``None``, a new uninitialised tensor shaped like ``x`` is used.
        return_dinitial_states: When truthy, a buffer for the gradient of the
            initial states is allocated and returned; otherwise ``None`` is
            passed to the kernel and returned.
        silu_activation: Forwarded to the kernel unchanged.

    Returns:
        ``(dx, dweight, dbias, dinitial_states)``. ``dweight`` and ``dbias``
        are cast to the dtypes of ``weight`` and ``bias``. ``dbias`` is
        ``None`` when ``bias`` is ``None``, and ``dinitial_states`` is
        ``None`` when it was not requested.
    """
    batch_size, dim = x.shape[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)

    # dweight/dbias are zero-initialised float32 buffers regardless of the
    # parameter dtype (presumably so the kernel can accumulate into them at
    # full precision); they are cast back to the parameter dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)

    dinitial_states = None
    if return_dinitial_states:
        # Allocated as (batch_size, width - 1, dim) and then transposed, so
        # the result is a (batch_size, dim, width - 1) view whose `dim` axis
        # is the innermost one in memory.
        dinitial_states = torch.empty(
            batch_size, width - 1, dim, device=x.device, dtype=x.dtype
        ).transpose(1, 2)

    _causal_conv1d_bwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    # Cast the float32 buffers back to the dtypes of the parameters.
    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
|
|
|
|
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Run the update custom op and return a freshly allocated output.

    Allocates an uninitialised tensor shaped like ``x``, lets
    ``_causal_conv1d_update_cpp`` fill it in place, and returns it.
    ``conv_state`` is forwarded as given; the op is registered as mutating
    it, so the caller's tensor may be modified.

    Returns:
        The output tensor, with the same shape and dtype as ``x``.
    """
    result = torch.empty_like(x)
    _causal_conv1d_update_cpp(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return result
|
|