diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..a10d9394f1035d070b28acafacb5479e20508d60 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6eb0fdb8827538d27d0822e22dd968059657aafdd8dca77b99d606e0026ae43b +size 80694456 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8fde84ac489ab3270264c00c60e3d562e78ec103 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78531cef5f05968a528ae8bc7a5a348b2abad1b180ac90142dd7df2491cef608 +size 107169824 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6f01490f3df550bd3906e70f05eb480e0cbf9fba --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8746e8c1e94e2022fe638316ba9cf89489d45d0d92047cafe54e554297a2c701 +size 64618464 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..5ff05f9899369b601aa80c960b3f025b256bf9c6 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49a73bdc1f6d9a32c2e107610f5ba22c2ca054a3efc1237a8291118af3191e7b +size 80684768 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..bc8b1dca66907aa286d15ec3bebdb510fc4f0a59 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43ea19b486dc11d1eb780e7c1c4944ad27d27713ab41b8824b14add98c5eb645 +size 107168432 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..35118540138a538805897101f81b6375e10c87b3 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a95ffd016cdfe6f401e0e495b0083a12395ffe82b3888c510b86f4a58dfe068 +size 115140584 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4b8e537aec32c521a76d7f422abf321ba932f4ae --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1652a695f38a12463ece0e84007e34575177c678c2432e97e1510064ea6b627a +size 80684856 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..27e888e460640fd2abf04ca37f69d56f0322a3f0 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae14d37443fbdadb3ffb7b5a4b6d46c01d7b17a965938ad9f6314dfa575c58de +size 107172616 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4df07e840976b195056bfc406e26080680d4f083 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cc8ad33b1164913858ea1cffce2f21f6f74d508da4cdd07de7d54aeb1c28ecc +size 64613056 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fa868509249a5659e14a7821104051efa3f250ad --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_1b44a8e +ops = torch.ops._causal_conv1d_1b44a8e + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file