diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d771f1504640a5756486abcdf1bcf21e45e92f0 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b4fa56b8116b17718afc9b62e3ee61a10fc8f85 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c28ccabf5d15b099f889bcc9c674704f2d2a35e Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a18fa8106e824330c12089d84cc6c7758c3fd7 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..405282a0b844ab76a98031c22d5f064bf48592c3 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d6fdbaf71bdf580e7d5816f235b28ed46359d79c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9875e2f9570a0bca9725184062f07e0f904e87a514d25175cda87e0c95a8666a +size 64503976 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1551eb5b53dbc3c61977f52179b70a089fd1df28 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6239878edc97c5a6d0759d26dcb32e98688bd1b Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb269fc8c92ec2ee2ee74e399e351e3a2e068bf1 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4fddd46d1b1dc938f04afa636a2d90c3fc3391 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..763ed3ed177891c42196ff33a9a3804929f1d2e7 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..74940a8b1b7e690a9ec507e8e3c3095bbf180a27 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7c97c4b57e74c91956eca7b812f88fa647a2c3676cc52b2bf585f79cc73d9ea +size 64213584 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e385c426b54adbe60e6f705951227631db9cfd18 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d38e35200729fde9a6a49b934268706adf8f489 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1083c3d0e18bb70e8900b92544e5faa660765504 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98a981b1953e8a0c05eb5e9ae9a371268e33adf Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6a734a841d689e2f2d71ae59cee104c82aa09f5 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3f3829da2340c302763d70d8470582cf239f6467 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d631b6843ce08b07172d0d91c3eb55c1d4d5b45839337e1eeeeb430054f97ce +size 102460944 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19b05560416a36434d6df3ca202a01344f56e5b8 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caae805f8c9f085caad051c4bd6c110f4f8a9c02 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..826f93b5fd568a8a03bb8cbed2994c7a9e7f15ca Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fffcd669af4726e9d0adf64fa8d1fe6159f2fe13 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbf994dd52d0457bdb89dfd81378f9c6a9ed0188 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..87bee4c541294833d811aa201e3b5d10aaa2a479 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd6b0804c4998af6b6dd9f64dfdabe98ac3db66fe78f6e5ea1a4db32d667a431 +size 64213808 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b5eb0a6b50d00f52d6da06563d1d28fc1e564b8 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5cf938d3e1eeebbb152880a32435e91b9ca5b5f Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2610dba23471323009b19ced6f9e2e669ee5b7e9 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39d2a0b66b76159be8e52cfcf6c75af7826f82fa Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5eaa70fc7ef0ce097eab4af9b3aac82a72eba7a Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ba9129f6115ce77aedda557ae5da82241b51ce9b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32eb5c27a38a2f9cad4076e3d3e47b1cfe1473604b7867eb87259872c49fe64a +size 102465272 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b116d0a0cb0e8aa4a404292f8bc2c14dcbd25e Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa7a049efac4ba97aac5288678ce588364d5868b Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a730525ba8d9fe9b2f313196e086388d4c3f1d6 Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c2596be71aa4fd852dc7077ec79a843998e862f Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045a8072a22dc785881ebdcc4173848ef8259a3d Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..481a6e6b7c040bdd092adfc01329d856e55af62f --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d95e7ebb7ad5f7881f5876269e47f5b28dd22b4de7c2d598476e037f67937c6 +size 110109736 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92fecce393e361dda107f47aa06cd5df2924281d --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_db03e28_dirty +ops = torch.ops._causal_conv1d_db03e28_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out