diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index a10d9394f1035d070b28acafacb5479e20508d60..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6eb0fdb8827538d27d0822e22dd968059657aafdd8dca77b99d606e0026ae43b -size 80694456 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 8fde84ac489ab3270264c00c60e3d562e78ec103..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:78531cef5f05968a528ae8bc7a5a348b2abad1b180ac90142dd7df2491cef608 -size 107169824 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 6f01490f3df550bd3906e70f05eb480e0cbf9fba..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8746e8c1e94e2022fe638316ba9cf89489d45d0d92047cafe54e554297a2c701 -size 64618464 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 5ff05f9899369b601aa80c960b3f025b256bf9c6..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:49a73bdc1f6d9a32c2e107610f5ba22c2ca054a3efc1237a8291118af3191e7b -size 80684768 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index bc8b1dca66907aa286d15ec3bebdb510fc4f0a59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:43ea19b486dc11d1eb780e7c1c4944ad27d27713ab41b8824b14add98c5eb645 -size 107168432 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 35118540138a538805897101f81b6375e10c87b3..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1a95ffd016cdfe6f401e0e495b0083a12395ffe82b3888c510b86f4a58dfe068 -size 115140584 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 4b8e537aec32c521a76d7f422abf321ba932f4ae..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1652a695f38a12463ece0e84007e34575177c678c2432e97e1510064ea6b627a -size 80684856 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 27e888e460640fd2abf04ca37f69d56f0322a3f0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ae14d37443fbdadb3ffb7b5a4b6d46c01d7b17a965938ad9f6314dfa575c58de -size 107172616 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py deleted file mode 100644 index e00d508d004ebe3eafe214d2b1b2ec2a44090d5c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update -from .causal_conv1d_varlen import causal_conv1d_varlen_states - -__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so deleted file mode 100644 index 4df07e840976b195056bfc406e26080680d4f083..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0cc8ad33b1164913858ea1cffce2f21f6f74d508da4cdd07de7d54aeb1c28ecc -size 64613056 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py deleted file mode 100644 index fa868509249a5659e14a7821104051efa3f250ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _causal_conv1d_1b44a8e -ops = torch.ops._causal_conv1d_1b44a8e - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_causal_conv1d_1b44a8e::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py deleted file mode 100644 index 03dbc1afe1cf156661a2b1b22003cd5f599a0309..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import sys - -import importlib -from pathlib import Path -from types import ModuleType - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py deleted file mode 100644 index d46d56c415e91e16a475a7261e01658b2259d377..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch -import torch.nn.functional as F - -from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert ( - initial_states is None - ), "initial_states must be None if seq_idx is not None" - assert ( - not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and ( - initial_states.stride(2) != 1 and initial_states.stride(1) != 1 - ): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert ( - final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 - ) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty( - batch, width - 1, dim, device=x.device, dtype=x.dtype - ).transpose(1, 2) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_fwd_function( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = ( - initial_states is not None and initial_states.requires_grad - ) - return out if not return_final_states else (out, final_states_out) - - @staticmethod - def backward(ctx, dout, *args): - x, weight, bias, seq_idx, initial_states = ctx.saved_tensors - dfinal_states = args[0] if ctx.return_final_states else None - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - seq_idx: (batch, seqlen) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1), to be written to - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_update_function( - x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices - ) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert conv_state.shape == (batch, dim, state_len) - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:]) - else: - width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) - copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py deleted file mode 100644 index 8005af233d5c21b0a58917a8a18045636c2351cb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import Tensor - -import triton -import triton.language as tl - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, stride_x_dim, - stride_states_batch, stride_states_seqlen, stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - with torch.cuda.device(x.device.index): - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.stride(0), x.stride(1), - states.stride(0), states.stride(2), states.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T - return states diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py deleted file mode 100644 index 1ddb9f83ceb4e9f72754fe39340738d47b6aea1b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import torch - -from ._ops import ops - -def causal_conv1d_fwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - final_states_out: torch.Tensor | None, - silu_activation: bool, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_fwd( - x=x, - weight=weight, - bias=bias, - seq_idx=seq_idx, - initial_states=initial_states, - out=out, - final_states_out=final_states_out, - silu_activation=silu_activation, - ) - return out - - -def causal_conv1d_bwd_function( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - dout: torch.Tensor, - seq_idx: torch.Tensor | None, - initial_states: torch.Tensor | None, - dfinal_states: torch.Tensor | None, - dx: torch.Tensor | None, - return_dinitial_states: torch.Tensor, - silu_activation: bool, -) -> tuple[torch.Tensor | None]: - batch_size, dim = x.size()[:2] - width = weight.size(-1) - - if dx is None: - dx = torch.empty_like(x) - dweight = torch.zeros_like(weight, dtype=torch.float32) - dbias = None - if bias is not None: - dbias = torch.zeros_like(bias, dtype=torch.float32) - dinitial_states = None - if return_dinitial_states: - dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) - - ops.causal_conv1d_bwd( - x=x, - weight=weight, - bias=bias, - dout=dout, - seq_idx=seq_idx, - initial_states=initial_states, - dfinal_states=dfinal_states, - dx=dx, - dweight=dweight, - dbias=dbias, - dinitial_states=dinitial_states, - silu_activation=silu_activation, - ) - - dweight = dweight.type_as(weight) - if dbias is not None: - dbias = dbias.type_as(bias) - return dx, dweight, dbias, dinitial_states - - -def causal_conv1d_update_function( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - silu_activation: bool, - cache_seqlens: torch.Tensor | None, - conv_state_indices: torch.Tensor | None, -) -> torch.Tensor: - out = torch.empty_like(x) - ops.causal_conv1d_update( - x=x, - conv_state=conv_state, - weight=weight, - bias=bias, - out=out, - silu_activation=silu_activation, - cache_seqlens=cache_seqlens, - conv_state_indices=conv_state_indices, - ) - return out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json deleted file mode 100644 index 76bafa5f33b6818aa6bb4cab04be811b87519b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"python-depends":[]} \ No newline at end of file