build-tools / causal_conv1d /cpp_functions.py
salmankhanpm's picture
Add files using upload-large-folder tool
dc9bb20 verified
# Copyright (c) 2024, Tri Dao.
import torch
import causal_conv1d_cuda
LIBRARY_NAME = "DaoAILab"
@torch.library.custom_op(
    f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp",
    mutates_args={"out", "final_states_out"},
)
def _causal_conv1d_fwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    out: torch.Tensor,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_fwd``.

    The op returns nothing. It is registered with ``out`` and
    ``final_states_out`` declared as mutated arguments, so callers must
    pre-allocate ``out`` (and ``final_states_out`` if they pass one).

    NOTE: the op schema is inferred from this signature's annotations, so
    parameter names and types here are part of the registered interface.
    """
    # Forwarded positionally, in the same order as the parameters above.
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        out,
        final_states_out,
        silu_activation,
    )
    causal_conv1d_cuda.causal_conv1d_fwd(*forward_args)
@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={
    "dfinal_states",
    "dx",
    "dweight",
    "dbias",
    "dinitial_states",
})
def _causal_conv1d_bwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor,
    dweight: torch.Tensor,
    dbias: torch.Tensor | None,
    dinitial_states: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_bwd``.

    The op returns nothing. It is registered with ``dfinal_states``, ``dx``,
    ``dweight``, ``dbias`` and ``dinitial_states`` declared as mutated
    arguments, so gradient buffers must be pre-allocated by the caller.

    ``dinitial_states`` is annotated as optional because the Python-level
    caller in this file passes ``None`` when no gradient for the initial
    states is requested. The op schema is inferred from these annotations,
    and a plain ``torch.Tensor`` annotation would make the op reject ``None``.
    NOTE(review): assumes the extension accepts an absent
    ``dinitial_states`` -- confirm against the C++ signature.
    """
    # Forwarded positionally, in the same order as the parameters above.
    causal_conv1d_cuda.causal_conv1d_bwd(
        x,
        weight,
        bias,
        dout,
        seq_idx,
        initial_states,
        dfinal_states,
        dx,
        dweight,
        dbias,
        dinitial_states,
        silu_activation,
    )
@torch.library.custom_op(
    f"{LIBRARY_NAME}::_causal_conv1d_update_cpp",
    mutates_args={"out", "conv_state"},
)
def _causal_conv1d_update_cpp(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    out: torch.Tensor,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> None:
    """Custom-op wrapper around ``causal_conv1d_cuda.causal_conv1d_update``.

    The op returns nothing. It is registered with ``out`` and ``conv_state``
    declared as mutated arguments, so ``out`` must be pre-allocated by the
    caller and ``conv_state`` should be expected to change.

    NOTE: the op schema is inferred from this signature's annotations, so
    parameter names and types here are part of the registered interface.
    """
    # Forwarded positionally, in the same order as the parameters above.
    update_args = (
        x,
        conv_state,
        weight,
        bias,
        out,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
    )
    causal_conv1d_cuda.causal_conv1d_update(*update_args)
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the forward custom op and return its freshly allocated output.

    The output buffer is created with ``torch.empty_like(x)`` and handed to
    ``_causal_conv1d_fwd_cpp`` to be filled. ``final_states_out`` is passed
    through unchanged; the op declares it as a mutated argument.

    Returns:
        The output tensor, allocated like ``x``.
    """
    # Uninitialized on purpose: the op declares ``out`` as mutated.
    result = torch.empty_like(x)
    _causal_conv1d_fwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return result
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Allocate gradient buffers, run the backward custom op, return gradients.

    Args:
        x: Forward input; its first two dimensions are read as
            ``(batch_size, dim)``.
        weight: Convolution weight; its last dimension is read as the
            kernel width.
        bias: Optional bias. When ``None`` no bias gradient is produced.
        dout: Gradient with respect to the forward output.
        seq_idx: Passed through to the op unchanged.
        initial_states: Passed through to the op unchanged.
        dfinal_states: Passed through to the op unchanged.
        dx: Optional pre-allocated buffer for the input gradient. When
            ``None`` a new buffer is created with ``torch.empty_like(x)``.
        return_dinitial_states: Whether to allocate and return a gradient
            buffer for the initial states.
        silu_activation: Passed through to the op unchanged.

    Returns:
        ``(dx, dweight, dbias, dinitial_states)``. ``dweight`` and ``dbias``
        are cast back to the dtypes of ``weight`` and ``bias``. ``dbias`` is
        ``None`` when ``bias`` is ``None``, and ``dinitial_states`` is
        ``None`` when ``return_dinitial_states`` is falsy.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)

    # Parameter gradients are held in float32 zero-initialized buffers and
    # cast back to the parameter dtype after the op has run.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None if bias is None else torch.zeros_like(bias, dtype=torch.float32)

    dinitial_states = None
    if return_dinitial_states:
        # Allocated as (batch, width - 1, dim), then viewed as
        # (batch, dim, width - 1) via transpose, so it is not contiguous in
        # the returned layout.
        dinitial_states = torch.empty(
            batch_size, width - 1, dim, device=x.device, dtype=x.dtype
        ).transpose(1, 2)

    _causal_conv1d_bwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Run the update custom op and return its freshly allocated output.

    The output buffer is created with ``torch.empty_like(x)`` and handed to
    ``_causal_conv1d_update_cpp`` to be filled. The op also declares
    ``conv_state`` as a mutated argument, so the caller's ``conv_state``
    should be expected to change.

    Returns:
        The output tensor, allocated like ``x``.
    """
    # Uninitialized on purpose: the op declares ``out`` as mutated.
    result = torch.empty_like(x)
    _causal_conv1d_update_cpp(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return result