Kernels
Committed by danieldk (HF Staff)
Commit 1a7b769 · verified · 1 parent: 48f2c67

Build uploaded using `kernels`.
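For context, prebuilt binaries like the ones removed here are normally consumed through the Hugging Face `kernels` library rather than imported from the `build/` tree by hand. A minimal sketch of that flow, assuming the kernel is published under a repository id such as `kernels-community/causal-conv1d` (the repo id, shapes, and dtypes below are illustrative, not taken from this commit):

```python
# Hypothetical usage sketch: fetch a prebuilt causal-conv1d kernel and call it.
# Assumes a CUDA device and that the repo id below actually hosts this kernel.
import torch
from kernels import get_kernel

causal_conv1d = get_kernel("kernels-community/causal-conv1d")

x = torch.randn(2, 64, 128, device="cuda", dtype=torch.float16)   # (batch, dim, seqlen)
weight = torch.randn(64, 4, device="cuda", dtype=torch.float16)   # (dim, width)
bias = torch.randn(64, device="cuda", dtype=torch.float16)        # (dim,)

# Same entry point that the deleted __init__.py files re-export.
out = causal_conv1d.causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # torch.Size([2, 64, 128])
```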

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +0 -4
  2. build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  3. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +0 -9
  4. build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py +0 -26
  5. build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py +0 -242
  6. build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py +0 -86
  7. build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py +0 -96
  8. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +0 -1
  9. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +0 -4
  10. build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  11. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +0 -9
  12. build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py +0 -26
  13. build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py +0 -242
  14. build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py +0 -86
  15. build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py +0 -96
  16. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +0 -1
  17. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +0 -4
  18. build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  19. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +0 -9
  20. build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py +0 -26
  21. build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py +0 -242
  22. build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py +0 -86
  23. build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py +0 -96
  24. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +0 -1
  25. build/torch28-cxx11-cu126-x86_64-linux/__init__.py +0 -4
  26. build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  27. build/torch28-cxx11-cu126-x86_64-linux/_ops.py +0 -9
  28. build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py +0 -26
  29. build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py +0 -242
  30. build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py +0 -86
  31. build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py +0 -96
  32. build/torch28-cxx11-cu126-x86_64-linux/metadata.json +0 -1
  33. build/torch28-cxx11-cu128-x86_64-linux/__init__.py +0 -4
  34. build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  35. build/torch28-cxx11-cu128-x86_64-linux/_ops.py +0 -9
  36. build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py +0 -26
  37. build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py +0 -242
  38. build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py +0 -86
  39. build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py +0 -96
  40. build/torch28-cxx11-cu128-x86_64-linux/metadata.json +0 -1
  41. build/torch28-cxx11-cu129-x86_64-linux/__init__.py +0 -4
  42. build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
  43. build/torch28-cxx11-cu129-x86_64-linux/_ops.py +0 -9
  44. build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py +0 -26
  45. build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py +0 -242
  46. build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py +0 -86
  47. build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py +0 -96
  48. build/torch28-cxx11-cu129-x86_64-linux/metadata.json +0 -1
  49. build/torch29-cxx11-cu126-x86_64-linux/__init__.py +0 -4
  50. build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so +0 -3
build/torch210-cxx11-cu126-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:6eb0fdb8827538d27d0822e22dd968059657aafdd8dca77b99d606e0026ae43b
- size 80694456
build/torch210-cxx11-cu126-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py DELETED
@@ -1,242 +0,0 @@
- # Copyright (c) 2024, Tri Dao.
-
- import torch
- import torch.nn.functional as F
-
- from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
-
-
- class CausalConv1dFn(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x,
-         weight,
-         bias=None,
-         seq_idx=None,
-         initial_states=None,
-         return_final_states=False,
-         final_states_out=None,
-         activation=None,
-     ):
-         if activation not in [None, "silu", "swish"]:
-             raise NotImplementedError("activation must be None, silu, or swish")
-         if x.stride(2) != 1 and x.stride(1) != 1:
-             x = x.contiguous()
-         bias = bias.contiguous() if bias is not None else None
-         if seq_idx is not None:
-             assert (
-                 initial_states is None
-             ), "initial_states must be None if seq_idx is not None"
-             assert (
-                 not return_final_states
-             ), "If seq_idx is not None, we don't return final_states_out"
-         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
-         if initial_states is not None and (
-             initial_states.stride(2) != 1 and initial_states.stride(1) != 1
-         ):
-             initial_states = initial_states.contiguous()
-         if return_final_states:
-             assert (
-                 x.stride(1) == 1
-             ), "Only channel-last layout support returning final_states_out"
-             if final_states_out is not None:
-                 assert (
-                     final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
-                 )
-             else:
-                 batch, dim, seqlen = x.shape
-                 width = weight.shape[1]
-                 final_states_out = torch.empty(
-                     batch, width - 1, dim, device=x.device, dtype=x.dtype
-                 ).transpose(1, 2)
-         else:
-             final_states_out = None
-         ctx.activation = activation in ["silu", "swish"]
-         out = causal_conv1d_fwd_function(
-             x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
-         )
-         ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
-         ctx.return_final_states = return_final_states
-         ctx.return_dinitial_states = (
-             initial_states is not None and initial_states.requires_grad
-         )
-         return out if not return_final_states else (out, final_states_out)
-
-     @staticmethod
-     def backward(ctx, dout, *args):
-         x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
-         dfinal_states = args[0] if ctx.return_final_states else None
-         if dout.stride(2) != 1 and dout.stride(1) != 1:
-             dout = dout.contiguous()
-         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-         # backward of conv1d with the backward of chunk).
-         # Here we just pass in None and dx will be allocated in the C++ code.
-         dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
-             x,
-             weight,
-             bias,
-             dout,
-             seq_idx,
-             initial_states,
-             dfinal_states,
-             None,
-             ctx.return_dinitial_states,
-             ctx.activation,
-         )
-         return (
-             dx,
-             dweight,
-             dbias if bias is not None else None,
-             None,
-             dinitial_states if initial_states is not None else None,
-             None,
-             None,
-             None,
-         )
-
-
- def causal_conv1d_fn(
-     x,
-     weight,
-     bias=None,
-     seq_idx=None,
-     initial_states=None,
-     return_final_states=False,
-     final_states_out=None,
-     activation=None,
- ):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-     seq_idx: (batch, seqlen)
-     initial_states: (batch, dim, width - 1)
-     final_states_out: (batch, dim, width - 1), to be written to
-     activation: either None or "silu" or "swish"
-
-     out: (batch, dim, seqlen)
-     """
-     return CausalConv1dFn.apply(
-         x,
-         weight,
-         bias,
-         seq_idx,
-         initial_states,
-         return_final_states,
-         final_states_out,
-         activation,
-     )
-
-
- def causal_conv1d_ref(
-     x,
-     weight,
-     bias=None,
-     initial_states=None,
-     return_final_states=False,
-     final_states_out=None,
-     activation=None,
- ):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-     initial_states: (batch, dim, width - 1)
-     final_states_out: (batch, dim, width - 1)
-
-     out: (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     x = x.to(weight.dtype)
-     seqlen = x.shape[-1]
-     dim, width = weight.shape
-     if initial_states is None:
-         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
-     else:
-         x = torch.cat([initial_states, x], dim=-1)
-         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
-     out = out[..., :seqlen]
-     if return_final_states:
-         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
-             dtype_in
-         )  # (batch, dim, width - 1)
-         if final_states_out is not None:
-             final_states_out.copy_(final_states)
-         else:
-             final_states_out = final_states
-     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-     return out if not return_final_states else (out, final_states_out)
-
-
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
-     """
-     x: (batch, dim) or (batch, dim, seqlen)
-     conv_state: (batch, dim, state_len), where state_len >= width - 1
-     weight: (dim, width)
-     bias: (dim,)
-     cache_seqlens: (batch,), dtype int32.
-         If not None, the conv_state is treated as a circular buffer.
-         The conv_state will be updated by copying x to the conv_state starting at the index
-         @cache_seqlens % state_len.
-     conv_state_indices: (batch,), dtype int32
-         If None, the conv_state is a larger tensor along the batch dim,
-         and we are selecting the batch coords specified by conv_state_indices.
-         Useful for a continuous batching scenario.
-
-     out: (batch, dim) or (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     activation = activation in ["silu", "swish"]
-     unsqueeze = x.dim() == 2
-     if unsqueeze:
-         x = x.unsqueeze(-1)
-     out = causal_conv1d_update_function(
-         x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
-     )
-     if unsqueeze:
-         out = out.squeeze(-1)
-     return out
-
-
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
-     """
-     x: (batch, dim) or (batch, dim, seqlen)
-     conv_state: (batch, dim, state_len), where state_len >= width - 1
-     weight: (dim, width)
-     bias: (dim,)
-     cache_seqlens: (batch,), dtype int32.
-         If not None, the conv_state is treated as a circular buffer.
-         The conv_state will be updated by copying x to the conv_state starting at the index
-         @cache_seqlens % state_len before performing the convolution.
-
-     out: (batch, dim) or (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     unsqueeze = x.dim() == 2
-     if unsqueeze:
-         x = x.unsqueeze(-1)
-     batch, dim, seqlen = x.shape
-     width = weight.shape[1]
-     state_len = conv_state.shape[-1]
-     assert conv_state.shape == (batch, dim, state_len)
-     assert weight.shape == (dim, width)
-     if cache_seqlens is None:
-         x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
-         conv_state.copy_(x_new[:, :, -state_len:])
-     else:
-         width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
-         width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-         x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
-         copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
-         copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-         conv_state.scatter_(2, copy_idx, x)
-     out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
-     if unsqueeze:
-         out = out.squeeze(-1)
-     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py DELETED
@@ -1,86 +0,0 @@
- import torch
- from torch import Tensor
-
- import triton
- import triton.language as tl
-
-
- @triton.jit
- def _causal_conv1d_varlen_states(
-     X,
-     CU_SEQLENS,
-     STATES,
-     state_len,
-     dim,
-     stride_x_seqlen, stride_x_dim,
-     stride_states_batch, stride_states_seqlen, stride_states_dim,
-     BLOCK_M: tl.constexpr,
-     BLOCK_N: tl.constexpr
- ):
-     batch_idx = tl.program_id(2)
-     STATES += batch_idx * stride_states_batch
-     end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
-     start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
-     rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
-     cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
-     x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
-                 mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
-                 other=0)
-     rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
-     tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
-              x,
-              mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
-
-
- def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
-     """
-     Forward pass only, does not support backward pass.
-     Parameters:
-         x: (total_tokens, dim)
-         cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
-         state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
-             If some of those elements belong to a different sequence, the value of the states will be zero.
-     Return:
-         states: (batch, dim, state_len)
-     """
-     _, dim = x.shape
-     batch = cu_seqlens.shape[0] - 1
-     cu_seqlens = cu_seqlens.contiguous()
-     states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
-     BLOCK_M = min(triton.next_power_of_2(state_len), 16)
-     BLOCK_N = min(triton.next_power_of_2(dim), 256)
-     grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
-     with torch.cuda.device(x.device.index):
-         _causal_conv1d_varlen_states[grid](
-             x,
-             cu_seqlens,
-             states,
-             state_len,
-             dim,
-             x.stride(0), x.stride(1),
-             states.stride(0), states.stride(2), states.stride(1),
-             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
-         )
-     return states
-
-
- def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
-     """
-     Forward pass only, does not support backward pass.
-     Parameters:
-         x: (total_tokens, dim)
-         cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
-         state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
-             If some of those elements belong to a different sequence, the value of the states will be zero.
-     Return:
-         states: (batch, dim, state_len)
-     """
-     _, dim = x.shape
-     batch = cu_seqlens.shape[0] - 1
-     cu_seqlens = cu_seqlens.contiguous()
-     states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
-     for i in range(batch):
-         end_idx = cu_seqlens[i + 1]
-         start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
-         states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
-     return states
build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py DELETED
@@ -1,96 +0,0 @@
- # Copyright (c) 2024, Tri Dao.
-
- import torch
-
- from ._ops import ops
-
- def causal_conv1d_fwd_function(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     seq_idx: torch.Tensor | None,
-     initial_states: torch.Tensor | None,
-     final_states_out: torch.Tensor | None,
-     silu_activation: bool,
- ) -> torch.Tensor:
-     out = torch.empty_like(x)
-     ops.causal_conv1d_fwd(
-         x=x,
-         weight=weight,
-         bias=bias,
-         seq_idx=seq_idx,
-         initial_states=initial_states,
-         out=out,
-         final_states_out=final_states_out,
-         silu_activation=silu_activation,
-     )
-     return out
-
-
- def causal_conv1d_bwd_function(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     dout: torch.Tensor,
-     seq_idx: torch.Tensor | None,
-     initial_states: torch.Tensor | None,
-     dfinal_states: torch.Tensor | None,
-     dx: torch.Tensor | None,
-     return_dinitial_states: torch.Tensor,
-     silu_activation: bool,
- ) -> tuple[torch.Tensor | None]:
-     batch_size, dim = x.size()[:2]
-     width = weight.size(-1)
-
-     if dx is None:
-         dx = torch.empty_like(x)
-     dweight = torch.zeros_like(weight, dtype=torch.float32)
-     dbias = None
-     if bias is not None:
-         dbias = torch.zeros_like(bias, dtype=torch.float32)
-     dinitial_states = None
-     if return_dinitial_states:
-         dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
-
-     ops.causal_conv1d_bwd(
-         x=x,
-         weight=weight,
-         bias=bias,
-         dout=dout,
-         seq_idx=seq_idx,
-         initial_states=initial_states,
-         dfinal_states=dfinal_states,
-         dx=dx,
-         dweight=dweight,
-         dbias=dbias,
-         dinitial_states=dinitial_states,
-         silu_activation=silu_activation,
-     )
-
-     dweight = dweight.type_as(weight)
-     if dbias is not None:
-         dbias = dbias.type_as(bias)
-     return dx, dweight, dbias, dinitial_states
-
-
- def causal_conv1d_update_function(
-     x: torch.Tensor,
-     conv_state: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     silu_activation: bool,
-     cache_seqlens: torch.Tensor | None,
-     conv_state_indices: torch.Tensor | None,
- ) -> torch.Tensor:
-     out = torch.empty_like(x)
-     ops.causal_conv1d_update(
-         x=x,
-         conv_state=conv_state,
-         weight=weight,
-         bias=bias,
-         out=out,
-         silu_activation=silu_activation,
-         cache_seqlens=cache_seqlens,
-         conv_state_indices=conv_state_indices,
-     )
-     return out
build/torch210-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch210-cxx11-cu128-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:78531cef5f05968a528ae8bc7a5a348b2abad1b180ac90142dd7df2491cef608
- size 107169824
build/torch210-cxx11-cu128-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py DELETED
@@ -1,242 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
-
8
-
9
- class CausalConv1dFn(torch.autograd.Function):
10
- @staticmethod
11
- def forward(
12
- ctx,
13
- x,
14
- weight,
15
- bias=None,
16
- seq_idx=None,
17
- initial_states=None,
18
- return_final_states=False,
19
- final_states_out=None,
20
- activation=None,
21
- ):
22
- if activation not in [None, "silu", "swish"]:
23
- raise NotImplementedError("activation must be None, silu, or swish")
24
- if x.stride(2) != 1 and x.stride(1) != 1:
25
- x = x.contiguous()
26
- bias = bias.contiguous() if bias is not None else None
27
- if seq_idx is not None:
28
- assert (
29
- initial_states is None
30
- ), "initial_states must be None if seq_idx is not None"
31
- assert (
32
- not return_final_states
33
- ), "If seq_idx is not None, we don't return final_states_out"
34
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
35
- if initial_states is not None and (
36
- initial_states.stride(2) != 1 and initial_states.stride(1) != 1
37
- ):
38
- initial_states = initial_states.contiguous()
39
- if return_final_states:
40
- assert (
41
- x.stride(1) == 1
42
- ), "Only channel-last layout support returning final_states_out"
43
- if final_states_out is not None:
44
- assert (
45
- final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
46
- )
47
- else:
48
- batch, dim, seqlen = x.shape
49
- width = weight.shape[1]
50
- final_states_out = torch.empty(
51
- batch, width - 1, dim, device=x.device, dtype=x.dtype
52
- ).transpose(1, 2)
53
- else:
54
- final_states_out = None
55
- ctx.activation = activation in ["silu", "swish"]
56
- out = causal_conv1d_fwd_function(
57
- x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
58
- )
59
- ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
60
- ctx.return_final_states = return_final_states
61
- ctx.return_dinitial_states = (
62
- initial_states is not None and initial_states.requires_grad
63
- )
64
- return out if not return_final_states else (out, final_states_out)
65
-
66
- @staticmethod
67
- def backward(ctx, dout, *args):
68
- x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
69
- dfinal_states = args[0] if ctx.return_final_states else None
70
- if dout.stride(2) != 1 and dout.stride(1) != 1:
71
- dout = dout.contiguous()
72
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
73
- # backward of conv1d with the backward of chunk).
74
- # Here we just pass in None and dx will be allocated in the C++ code.
75
- dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
76
- x,
77
- weight,
78
- bias,
79
- dout,
80
- seq_idx,
81
- initial_states,
82
- dfinal_states,
83
- None,
84
- ctx.return_dinitial_states,
85
- ctx.activation,
86
- )
87
- return (
88
- dx,
89
- dweight,
90
- dbias if bias is not None else None,
91
- None,
92
- dinitial_states if initial_states is not None else None,
93
- None,
94
- None,
95
- None,
96
- )
97
-
98
-
99
- def causal_conv1d_fn(
100
- x,
101
- weight,
102
- bias=None,
103
- seq_idx=None,
104
- initial_states=None,
105
- return_final_states=False,
106
- final_states_out=None,
107
- activation=None,
108
- ):
109
- """
110
- x: (batch, dim, seqlen)
111
- weight: (dim, width)
112
- bias: (dim,)
113
- seq_idx: (batch, seqlen)
114
- initial_states: (batch, dim, width - 1)
115
- final_states_out: (batch, dim, width - 1), to be written to
116
- activation: either None or "silu" or "swish"
117
-
118
- out: (batch, dim, seqlen)
119
- """
120
- return CausalConv1dFn.apply(
121
- x,
122
- weight,
123
- bias,
124
- seq_idx,
125
- initial_states,
126
- return_final_states,
127
- final_states_out,
128
- activation,
129
- )
130
-
131
-
132
- def causal_conv1d_ref(
133
- x,
134
- weight,
135
- bias=None,
136
- initial_states=None,
137
- return_final_states=False,
138
- final_states_out=None,
139
- activation=None,
140
- ):
141
- """
142
- x: (batch, dim, seqlen)
143
- weight: (dim, width)
144
- bias: (dim,)
145
- initial_states: (batch, dim, width - 1)
146
- final_states_out: (batch, dim, width - 1)
147
-
148
- out: (batch, dim, seqlen)
149
- """
150
- if activation not in [None, "silu", "swish"]:
151
- raise NotImplementedError("activation must be None, silu, or swish")
152
- dtype_in = x.dtype
153
- x = x.to(weight.dtype)
154
- seqlen = x.shape[-1]
155
- dim, width = weight.shape
156
- if initial_states is None:
157
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
158
- else:
159
- x = torch.cat([initial_states, x], dim=-1)
160
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
161
- out = out[..., :seqlen]
162
- if return_final_states:
163
- final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
164
- dtype_in
165
- ) # (batch, dim, width - 1)
166
- if final_states_out is not None:
167
- final_states_out.copy_(final_states)
168
- else:
169
- final_states_out = final_states
170
- out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
171
- return out if not return_final_states else (out, final_states_out)
172
-
173
-
174
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
175
- """
176
- x: (batch, dim) or (batch, dim, seqlen)
177
- conv_state: (batch, dim, state_len), where state_len >= width - 1
178
- weight: (dim, width)
179
- bias: (dim,)
180
- cache_seqlens: (batch,), dtype int32.
181
- If not None, the conv_state is treated as a circular buffer.
182
- The conv_state will be updated by copying x to the conv_state starting at the index
183
- @cache_seqlens % state_len.
184
- conv_state_indices: (batch,), dtype int32
185
- If None, the conv_state is a larger tensor along the batch dim,
186
- and we are selecting the batch coords specified by conv_state_indices.
187
- Useful for a continuous batching scenario.
188
-
189
- out: (batch, dim) or (batch, dim, seqlen)
190
- """
191
- if activation not in [None, "silu", "swish"]:
192
- raise NotImplementedError("activation must be None, silu, or swish")
193
- activation = activation in ["silu", "swish"]
194
- unsqueeze = x.dim() == 2
195
- if unsqueeze:
196
- x = x.unsqueeze(-1)
197
- out = causal_conv1d_update_function(
198
- x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
199
- )
200
- if unsqueeze:
201
- out = out.squeeze(-1)
202
- return out
203
-
204
-
205
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
206
- """
207
- x: (batch, dim) or (batch, dim, seqlen)
208
- conv_state: (batch, dim, state_len), where state_len >= width - 1
209
- weight: (dim, width)
210
- bias: (dim,)
211
- cache_seqlens: (batch,), dtype int32.
212
- If not None, the conv_state is treated as a circular buffer.
213
- The conv_state will be updated by copying x to the conv_state starting at the index
214
- @cache_seqlens % state_len before performing the convolution.
215
-
216
- out: (batch, dim) or (batch, dim, seqlen)
217
- """
218
- if activation not in [None, "silu", "swish"]:
219
- raise NotImplementedError("activation must be None, silu, or swish")
220
- dtype_in = x.dtype
221
- unsqueeze = x.dim() == 2
222
- if unsqueeze:
223
- x = x.unsqueeze(-1)
224
- batch, dim, seqlen = x.shape
225
- width = weight.shape[1]
226
- state_len = conv_state.shape[-1]
227
- assert conv_state.shape == (batch, dim, state_len)
228
- assert weight.shape == (dim, width)
229
- if cache_seqlens is None:
230
- x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
231
- conv_state.copy_(x_new[:, :, -state_len:])
232
- else:
233
- width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
- width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
- x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
236
- copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
237
- copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
238
- conv_state.scatter_(2, copy_idx, x)
239
- out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
240
- if unsqueeze:
241
- out = out.squeeze(-1)
242
- return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
-
4
- import triton
5
- import triton.language as tl
6
-
7
-
8
- @triton.jit
9
- def _causal_conv1d_varlen_states(
10
- X,
11
- CU_SEQLENS,
12
- STATES,
13
- state_len,
14
- dim,
15
- stride_x_seqlen, stride_x_dim,
16
- stride_states_batch, stride_states_seqlen, stride_states_dim,
17
- BLOCK_M: tl.constexpr,
18
- BLOCK_N: tl.constexpr
19
- ):
20
- batch_idx = tl.program_id(2)
21
- STATES += batch_idx * stride_states_batch
22
- end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
- start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
- rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
- cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
- x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
- mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
- other=0)
29
- rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
- tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
- x,
32
- mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
-
34
-
35
- def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
- """
37
- Forward pass only, does not support backward pass.
38
- Parameters:
39
- x: (total_tokens, dim)
40
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
- If some of those elements belong to a different sequence, the value of the states will be zero.
43
- Return:
44
- states: (batch, dim, state_len)
45
- """
46
- _, dim = x.shape
47
- batch = cu_seqlens.shape[0] - 1
48
- cu_seqlens = cu_seqlens.contiguous()
49
- states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
- BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
- BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
- grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
- with torch.cuda.device(x.device.index):
54
- _causal_conv1d_varlen_states[grid](
55
- x,
56
- cu_seqlens,
57
- states,
58
- state_len,
59
- dim,
60
- x.stride(0), x.stride(1),
61
- states.stride(0), states.stride(2), states.stride(1),
62
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
- )
64
- return states
65
-
66
-
67
- def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
- """
69
- Forward pass only, does not support backward pass.
70
- Parameters:
71
- x: (total_tokens, dim)
72
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
- If some of those elements belong to a different sequence, the value of the states will be zero.
75
- Return:
76
- states: (batch, dim, state_len)
77
- """
78
- _, dim = x.shape
79
- batch = cu_seqlens.shape[0] - 1
80
- cu_seqlens = cu_seqlens.contiguous()
81
- states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
- for i in range(batch):
83
- end_idx = cu_seqlens[i + 1]
84
- start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
- states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
- return states
build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py DELETED
@@ -1,96 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import torch
4
-
5
- from ._ops import ops
6
-
7
- def causal_conv1d_fwd_function(
8
- x: torch.Tensor,
9
- weight: torch.Tensor,
10
- bias: torch.Tensor | None,
11
- seq_idx: torch.Tensor | None,
12
- initial_states: torch.Tensor | None,
13
- final_states_out: torch.Tensor | None,
14
- silu_activation: bool,
15
- ) -> torch.Tensor:
16
- out = torch.empty_like(x)
17
- ops.causal_conv1d_fwd(
18
- x=x,
19
- weight=weight,
20
- bias=bias,
21
- seq_idx=seq_idx,
22
- initial_states=initial_states,
23
- out=out,
24
- final_states_out=final_states_out,
25
- silu_activation=silu_activation,
26
- )
27
- return out
28
-
29
-
30
- def causal_conv1d_bwd_function(
31
- x: torch.Tensor,
32
- weight: torch.Tensor,
33
- bias: torch.Tensor | None,
34
- dout: torch.Tensor,
35
- seq_idx: torch.Tensor | None,
36
- initial_states: torch.Tensor | None,
37
- dfinal_states: torch.Tensor | None,
38
- dx: torch.Tensor | None,
39
- return_dinitial_states: torch.Tensor,
40
- silu_activation: bool,
41
- ) -> tuple[torch.Tensor | None]:
42
- batch_size, dim = x.size()[:2]
43
- width = weight.size(-1)
44
-
45
- if dx is None:
46
- dx = torch.empty_like(x)
47
- dweight = torch.zeros_like(weight, dtype=torch.float32)
48
- dbias = None
49
- if bias is not None:
50
- dbias = torch.zeros_like(bias, dtype=torch.float32)
51
- dinitial_states = None
52
- if return_dinitial_states:
53
- dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
54
-
55
- ops.causal_conv1d_bwd(
56
- x=x,
57
- weight=weight,
58
- bias=bias,
59
- dout=dout,
60
- seq_idx=seq_idx,
61
- initial_states=initial_states,
62
- dfinal_states=dfinal_states,
63
- dx=dx,
64
- dweight=dweight,
65
- dbias=dbias,
66
- dinitial_states=dinitial_states,
67
- silu_activation=silu_activation,
68
- )
69
-
70
- dweight = dweight.type_as(weight)
71
- if dbias is not None:
72
- dbias = dbias.type_as(bias)
73
- return dx, dweight, dbias, dinitial_states
74
-
75
-
76
- def causal_conv1d_update_function(
77
- x: torch.Tensor,
78
- conv_state: torch.Tensor,
79
- weight: torch.Tensor,
80
- bias: torch.Tensor | None,
81
- silu_activation: bool,
82
- cache_seqlens: torch.Tensor | None,
83
- conv_state_indices: torch.Tensor | None,
84
- ) -> torch.Tensor:
85
- out = torch.empty_like(x)
86
- ops.causal_conv1d_update(
87
- x=x,
88
- conv_state=conv_state,
89
- weight=weight,
90
- bias=bias,
91
- out=out,
92
- silu_activation=silu_activation,
93
- cache_seqlens=cache_seqlens,
94
- conv_state_indices=conv_state_indices,
95
- )
96
- return out
build/torch210-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch210-cxx11-cu130-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8746e8c1e94e2022fe638316ba9cf89489d45d0d92047cafe54e554297a2c701
- size 64618464
build/torch210-cxx11-cu130-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py DELETED
@@ -1,242 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
-
8
-
9
- class CausalConv1dFn(torch.autograd.Function):
10
- @staticmethod
11
- def forward(
12
- ctx,
13
- x,
14
- weight,
15
- bias=None,
16
- seq_idx=None,
17
- initial_states=None,
18
- return_final_states=False,
19
- final_states_out=None,
20
- activation=None,
21
- ):
22
- if activation not in [None, "silu", "swish"]:
23
- raise NotImplementedError("activation must be None, silu, or swish")
24
- if x.stride(2) != 1 and x.stride(1) != 1:
25
- x = x.contiguous()
26
- bias = bias.contiguous() if bias is not None else None
27
- if seq_idx is not None:
28
- assert (
29
- initial_states is None
30
- ), "initial_states must be None if seq_idx is not None"
31
- assert (
32
- not return_final_states
33
- ), "If seq_idx is not None, we don't return final_states_out"
34
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
35
- if initial_states is not None and (
36
- initial_states.stride(2) != 1 and initial_states.stride(1) != 1
37
- ):
38
- initial_states = initial_states.contiguous()
39
- if return_final_states:
40
- assert (
41
- x.stride(1) == 1
42
- ), "Only channel-last layout support returning final_states_out"
43
- if final_states_out is not None:
44
- assert (
45
- final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
46
- )
47
- else:
48
- batch, dim, seqlen = x.shape
49
- width = weight.shape[1]
50
- final_states_out = torch.empty(
51
- batch, width - 1, dim, device=x.device, dtype=x.dtype
52
- ).transpose(1, 2)
53
- else:
54
- final_states_out = None
55
- ctx.activation = activation in ["silu", "swish"]
56
- out = causal_conv1d_fwd_function(
57
- x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
58
- )
59
- ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
60
- ctx.return_final_states = return_final_states
61
- ctx.return_dinitial_states = (
62
- initial_states is not None and initial_states.requires_grad
63
- )
64
- return out if not return_final_states else (out, final_states_out)
65
-
66
- @staticmethod
67
- def backward(ctx, dout, *args):
68
- x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
69
- dfinal_states = args[0] if ctx.return_final_states else None
70
- if dout.stride(2) != 1 and dout.stride(1) != 1:
71
- dout = dout.contiguous()
72
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
73
- # backward of conv1d with the backward of chunk).
74
- # Here we just pass in None and dx will be allocated in the C++ code.
75
- dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
76
- x,
77
- weight,
78
- bias,
79
- dout,
80
- seq_idx,
81
- initial_states,
82
- dfinal_states,
83
- None,
84
- ctx.return_dinitial_states,
85
- ctx.activation,
86
- )
87
- return (
88
- dx,
89
- dweight,
90
- dbias if bias is not None else None,
91
- None,
92
- dinitial_states if initial_states is not None else None,
93
- None,
94
- None,
95
- None,
96
- )
97
-
98
-
99
- def causal_conv1d_fn(
100
- x,
101
- weight,
102
- bias=None,
103
- seq_idx=None,
104
- initial_states=None,
105
- return_final_states=False,
106
- final_states_out=None,
107
- activation=None,
108
- ):
109
- """
110
- x: (batch, dim, seqlen)
111
- weight: (dim, width)
112
- bias: (dim,)
113
- seq_idx: (batch, seqlen)
114
- initial_states: (batch, dim, width - 1)
115
- final_states_out: (batch, dim, width - 1), to be written to
116
- activation: either None or "silu" or "swish"
117
-
118
- out: (batch, dim, seqlen)
119
- """
120
- return CausalConv1dFn.apply(
121
- x,
122
- weight,
123
- bias,
124
- seq_idx,
125
- initial_states,
126
- return_final_states,
127
- final_states_out,
128
- activation,
129
- )
130
-
131
-
132
- def causal_conv1d_ref(
133
- x,
134
- weight,
135
- bias=None,
136
- initial_states=None,
137
- return_final_states=False,
138
- final_states_out=None,
139
- activation=None,
140
- ):
141
- """
142
- x: (batch, dim, seqlen)
143
- weight: (dim, width)
144
- bias: (dim,)
145
- initial_states: (batch, dim, width - 1)
146
- final_states_out: (batch, dim, width - 1)
147
-
148
- out: (batch, dim, seqlen)
149
- """
150
- if activation not in [None, "silu", "swish"]:
151
- raise NotImplementedError("activation must be None, silu, or swish")
152
- dtype_in = x.dtype
153
- x = x.to(weight.dtype)
154
- seqlen = x.shape[-1]
155
- dim, width = weight.shape
156
- if initial_states is None:
157
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
158
- else:
159
- x = torch.cat([initial_states, x], dim=-1)
160
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
161
- out = out[..., :seqlen]
162
- if return_final_states:
163
- final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
164
- dtype_in
165
- ) # (batch, dim, width - 1)
166
- if final_states_out is not None:
167
- final_states_out.copy_(final_states)
168
- else:
169
- final_states_out = final_states
170
- out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
171
- return out if not return_final_states else (out, final_states_out)
172
-
173
-
174
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
175
- """
176
- x: (batch, dim) or (batch, dim, seqlen)
177
- conv_state: (batch, dim, state_len), where state_len >= width - 1
178
- weight: (dim, width)
179
- bias: (dim,)
180
- cache_seqlens: (batch,), dtype int32.
181
- If not None, the conv_state is treated as a circular buffer.
182
- The conv_state will be updated by copying x to the conv_state starting at the index
183
- @cache_seqlens % state_len.
184
- conv_state_indices: (batch,), dtype int32
185
- If None, the conv_state is a larger tensor along the batch dim,
186
- and we are selecting the batch coords specified by conv_state_indices.
187
- Useful for a continuous batching scenario.
188
-
189
- out: (batch, dim) or (batch, dim, seqlen)
190
- """
191
- if activation not in [None, "silu", "swish"]:
192
- raise NotImplementedError("activation must be None, silu, or swish")
193
- activation = activation in ["silu", "swish"]
194
- unsqueeze = x.dim() == 2
195
- if unsqueeze:
196
- x = x.unsqueeze(-1)
197
- out = causal_conv1d_update_function(
198
- x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
199
- )
200
- if unsqueeze:
201
- out = out.squeeze(-1)
202
- return out
203
-
204
-
205
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
206
- """
207
- x: (batch, dim) or (batch, dim, seqlen)
208
- conv_state: (batch, dim, state_len), where state_len >= width - 1
209
- weight: (dim, width)
210
- bias: (dim,)
211
- cache_seqlens: (batch,), dtype int32.
212
- If not None, the conv_state is treated as a circular buffer.
213
- The conv_state will be updated by copying x to the conv_state starting at the index
214
- @cache_seqlens % state_len before performing the convolution.
215
-
216
- out: (batch, dim) or (batch, dim, seqlen)
217
- """
218
- if activation not in [None, "silu", "swish"]:
219
- raise NotImplementedError("activation must be None, silu, or swish")
220
- dtype_in = x.dtype
221
- unsqueeze = x.dim() == 2
222
- if unsqueeze:
223
- x = x.unsqueeze(-1)
224
- batch, dim, seqlen = x.shape
225
- width = weight.shape[1]
226
- state_len = conv_state.shape[-1]
227
- assert conv_state.shape == (batch, dim, state_len)
228
- assert weight.shape == (dim, width)
229
- if cache_seqlens is None:
230
- x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
231
- conv_state.copy_(x_new[:, :, -state_len:])
232
- else:
233
- width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
- width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
- x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
236
- copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
237
- copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
238
- conv_state.scatter_(2, copy_idx, x)
239
- out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
240
- if unsqueeze:
241
- out = out.squeeze(-1)
242
- return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
-
4
- import triton
5
- import triton.language as tl
6
-
7
-
8
- @triton.jit
9
- def _causal_conv1d_varlen_states(
10
- X,
11
- CU_SEQLENS,
12
- STATES,
13
- state_len,
14
- dim,
15
- stride_x_seqlen, stride_x_dim,
16
- stride_states_batch, stride_states_seqlen, stride_states_dim,
17
- BLOCK_M: tl.constexpr,
18
- BLOCK_N: tl.constexpr
19
- ):
20
- batch_idx = tl.program_id(2)
21
- STATES += batch_idx * stride_states_batch
22
- end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
- start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
- rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
- cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
- x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
- mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
- other=0)
29
- rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
- tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
- x,
32
- mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
-
34
-
35
- def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
- """
37
- Forward pass only, does not support backward pass.
38
- Parameters:
39
- x: (total_tokens, dim)
40
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
- If some of those elements belong to a different sequence, the value of the states will be zero.
43
- Return:
44
- states: (batch, dim, state_len)
45
- """
46
- _, dim = x.shape
47
- batch = cu_seqlens.shape[0] - 1
48
- cu_seqlens = cu_seqlens.contiguous()
49
- states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
- BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
- BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
- grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
- with torch.cuda.device(x.device.index):
54
- _causal_conv1d_varlen_states[grid](
55
- x,
56
- cu_seqlens,
57
- states,
58
- state_len,
59
- dim,
60
- x.stride(0), x.stride(1),
61
- states.stride(0), states.stride(2), states.stride(1),
62
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
- )
64
- return states
65
-
66
-
67
- def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
- """
69
- Forward pass only, does not support backward pass.
70
- Parameters:
71
- x: (total_tokens, dim)
72
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
- If some of those elements belong to a different sequence, the value of the states will be zero.
75
- Return:
76
- states: (batch, dim, state_len)
77
- """
78
- _, dim = x.shape
79
- batch = cu_seqlens.shape[0] - 1
80
- cu_seqlens = cu_seqlens.contiguous()
81
- states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
- for i in range(batch):
83
- end_idx = cu_seqlens[i + 1]
84
- start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
- states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
- return states
build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py DELETED
@@ -1,96 +0,0 @@
- # Copyright (c) 2024, Tri Dao.
-
- import torch
-
- from ._ops import ops
-
- def causal_conv1d_fwd_function(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     seq_idx: torch.Tensor | None,
-     initial_states: torch.Tensor | None,
-     final_states_out: torch.Tensor | None,
-     silu_activation: bool,
- ) -> torch.Tensor:
-     out = torch.empty_like(x)
-     ops.causal_conv1d_fwd(
-         x=x,
-         weight=weight,
-         bias=bias,
-         seq_idx=seq_idx,
-         initial_states=initial_states,
-         out=out,
-         final_states_out=final_states_out,
-         silu_activation=silu_activation,
-     )
-     return out
-
-
- def causal_conv1d_bwd_function(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     dout: torch.Tensor,
-     seq_idx: torch.Tensor | None,
-     initial_states: torch.Tensor | None,
-     dfinal_states: torch.Tensor | None,
-     dx: torch.Tensor | None,
-     return_dinitial_states: torch.Tensor,
-     silu_activation: bool,
- ) -> tuple[torch.Tensor | None]:
-     batch_size, dim = x.size()[:2]
-     width = weight.size(-1)
-
-     if dx is None:
-         dx = torch.empty_like(x)
-     dweight = torch.zeros_like(weight, dtype=torch.float32)
-     dbias = None
-     if bias is not None:
-         dbias = torch.zeros_like(bias, dtype=torch.float32)
-     dinitial_states = None
-     if return_dinitial_states:
-         dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
-
-     ops.causal_conv1d_bwd(
-         x=x,
-         weight=weight,
-         bias=bias,
-         dout=dout,
-         seq_idx=seq_idx,
-         initial_states=initial_states,
-         dfinal_states=dfinal_states,
-         dx=dx,
-         dweight=dweight,
-         dbias=dbias,
-         dinitial_states=dinitial_states,
-         silu_activation=silu_activation,
-     )
-
-     dweight = dweight.type_as(weight)
-     if dbias is not None:
-         dbias = dbias.type_as(bias)
-     return dx, dweight, dbias, dinitial_states
-
-
- def causal_conv1d_update_function(
-     x: torch.Tensor,
-     conv_state: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor | None,
-     silu_activation: bool,
-     cache_seqlens: torch.Tensor | None,
-     conv_state_indices: torch.Tensor | None,
- ) -> torch.Tensor:
-     out = torch.empty_like(x)
-     ops.causal_conv1d_update(
-         x=x,
-         conv_state=conv_state,
-         weight=weight,
-         bias=bias,
-         out=out,
-         silu_activation=silu_activation,
-         cache_seqlens=cache_seqlens,
-         conv_state_indices=conv_state_indices,
-     )
-     return out
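A rough usage sketch for the forward wrapper above. This assumes the compiled `_causal_conv1d_1b44a8e` extension from this build directory is importable (so that `ops.causal_conv1d_fwd` exists) and that a CUDA device is available; the shapes follow the interface docstrings and are otherwise made up:

import torch

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

out = causal_conv1d_fwd_function(
    x, weight, bias,
    seq_idx=None, initial_states=None, final_states_out=None,
    silu_activation=True,
)
# The wrapper allocates the output with empty_like(x), so shapes always match.
assert out.shape == x.shape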
build/torch210-cxx11-cu130-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch28-cxx11-cu126-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:49a73bdc1f6d9a32c2e107610f5ba22c2ca054a3efc1237a8291118af3191e7b
- size 80684768
build/torch28-cxx11-cu126-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
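For reference, the helper simply qualifies an op name with the namespace of the bundled extension; a trivial sketch:

# e.g. "_causal_conv1d_1b44a8e::causal_conv1d_fwd"
print(add_op_namespace_prefix("causal_conv1d_fwd"))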
build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
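The shim above makes the `causal_conv1d` subpackage re-export the variant directory's top-level `__init__.py`, while keeping the `sys.modules` entry unique per path. A hedged sketch of calling the helper directly (the path below is hypothetical, and executing the module still requires the bundled extension to load):

from pathlib import Path

mod = _import_from_path(Path("build/torch28-cxx11-cu126-x86_64-linux/__init__.py"))
print(mod.__all__)  # ['causal_conv1d_fn', 'causal_conv1d_update', 'causal_conv1d_varlen_states']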
build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py DELETED
@@ -1,242 +0,0 @@
- # Copyright (c) 2024, Tri Dao.
-
- import torch
- import torch.nn.functional as F
-
- from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
-
-
- class CausalConv1dFn(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x,
-         weight,
-         bias=None,
-         seq_idx=None,
-         initial_states=None,
-         return_final_states=False,
-         final_states_out=None,
-         activation=None,
-     ):
-         if activation not in [None, "silu", "swish"]:
-             raise NotImplementedError("activation must be None, silu, or swish")
-         if x.stride(2) != 1 and x.stride(1) != 1:
-             x = x.contiguous()
-         bias = bias.contiguous() if bias is not None else None
-         if seq_idx is not None:
-             assert (
-                 initial_states is None
-             ), "initial_states must be None if seq_idx is not None"
-             assert (
-                 not return_final_states
-             ), "If seq_idx is not None, we don't return final_states_out"
-         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
-         if initial_states is not None and (
-             initial_states.stride(2) != 1 and initial_states.stride(1) != 1
-         ):
-             initial_states = initial_states.contiguous()
-         if return_final_states:
-             assert (
-                 x.stride(1) == 1
-             ), "Only channel-last layout support returning final_states_out"
-             if final_states_out is not None:
-                 assert (
-                     final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
-                 )
-             else:
-                 batch, dim, seqlen = x.shape
-                 width = weight.shape[1]
-                 final_states_out = torch.empty(
-                     batch, width - 1, dim, device=x.device, dtype=x.dtype
-                 ).transpose(1, 2)
-         else:
-             final_states_out = None
-         ctx.activation = activation in ["silu", "swish"]
-         out = causal_conv1d_fwd_function(
-             x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
-         )
-         ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
-         ctx.return_final_states = return_final_states
-         ctx.return_dinitial_states = (
-             initial_states is not None and initial_states.requires_grad
-         )
-         return out if not return_final_states else (out, final_states_out)
-
-     @staticmethod
-     def backward(ctx, dout, *args):
-         x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
-         dfinal_states = args[0] if ctx.return_final_states else None
-         if dout.stride(2) != 1 and dout.stride(1) != 1:
-             dout = dout.contiguous()
-         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-         # backward of conv1d with the backward of chunk).
-         # Here we just pass in None and dx will be allocated in the C++ code.
-         dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
-             x,
-             weight,
-             bias,
-             dout,
-             seq_idx,
-             initial_states,
-             dfinal_states,
-             None,
-             ctx.return_dinitial_states,
-             ctx.activation,
-         )
-         return (
-             dx,
-             dweight,
-             dbias if bias is not None else None,
-             None,
-             dinitial_states if initial_states is not None else None,
-             None,
-             None,
-             None,
-         )
-
-
- def causal_conv1d_fn(
-     x,
-     weight,
-     bias=None,
-     seq_idx=None,
-     initial_states=None,
-     return_final_states=False,
-     final_states_out=None,
-     activation=None,
- ):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-     seq_idx: (batch, seqlen)
-     initial_states: (batch, dim, width - 1)
-     final_states_out: (batch, dim, width - 1), to be written to
-     activation: either None or "silu" or "swish"
-
-     out: (batch, dim, seqlen)
-     """
-     return CausalConv1dFn.apply(
-         x,
-         weight,
-         bias,
-         seq_idx,
-         initial_states,
-         return_final_states,
-         final_states_out,
-         activation,
-     )
-
-
- def causal_conv1d_ref(
-     x,
-     weight,
-     bias=None,
-     initial_states=None,
-     return_final_states=False,
-     final_states_out=None,
-     activation=None,
- ):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-     initial_states: (batch, dim, width - 1)
-     final_states_out: (batch, dim, width - 1)
-
-     out: (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     x = x.to(weight.dtype)
-     seqlen = x.shape[-1]
-     dim, width = weight.shape
-     if initial_states is None:
-         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
-     else:
-         x = torch.cat([initial_states, x], dim=-1)
-         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
-     out = out[..., :seqlen]
-     if return_final_states:
-         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
-             dtype_in
-         )  # (batch, dim, width - 1)
-         if final_states_out is not None:
-             final_states_out.copy_(final_states)
-         else:
-             final_states_out = final_states
-     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-     return out if not return_final_states else (out, final_states_out)
-
-
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
-     """
-     x: (batch, dim) or (batch, dim, seqlen)
-     conv_state: (batch, dim, state_len), where state_len >= width - 1
-     weight: (dim, width)
-     bias: (dim,)
-     cache_seqlens: (batch,), dtype int32.
-         If not None, the conv_state is treated as a circular buffer.
-         The conv_state will be updated by copying x to the conv_state starting at the index
-         @cache_seqlens % state_len.
-     conv_state_indices: (batch,), dtype int32
-         If None, the conv_state is a larger tensor along the batch dim,
-         and we are selecting the batch coords specified by conv_state_indices.
-         Useful for a continuous batching scenario.
-
-     out: (batch, dim) or (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     activation = activation in ["silu", "swish"]
-     unsqueeze = x.dim() == 2
-     if unsqueeze:
-         x = x.unsqueeze(-1)
-     out = causal_conv1d_update_function(
-         x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
-     )
-     if unsqueeze:
-         out = out.squeeze(-1)
-     return out
-
-
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
-     """
-     x: (batch, dim) or (batch, dim, seqlen)
-     conv_state: (batch, dim, state_len), where state_len >= width - 1
-     weight: (dim, width)
-     bias: (dim,)
-     cache_seqlens: (batch,), dtype int32.
-         If not None, the conv_state is treated as a circular buffer.
-         The conv_state will be updated by copying x to the conv_state starting at the index
-         @cache_seqlens % state_len before performing the convolution.
-
-     out: (batch, dim) or (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     unsqueeze = x.dim() == 2
-     if unsqueeze:
-         x = x.unsqueeze(-1)
-     batch, dim, seqlen = x.shape
-     width = weight.shape[1]
-     state_len = conv_state.shape[-1]
-     assert conv_state.shape == (batch, dim, state_len)
-     assert weight.shape == (dim, width)
-     if cache_seqlens is None:
-         x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
-         conv_state.copy_(x_new[:, :, -state_len:])
-     else:
-         width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
-         width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-         x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
-         copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
-         copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
-         conv_state.scatter_(2, copy_idx, x)
-     out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
-     if unsqueeze:
-         out = out.squeeze(-1)
-     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
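A small end-to-end sketch using only the pure-PyTorch references above (`causal_conv1d_ref` and `causal_conv1d_update_ref`), so no compiled kernel is required; shapes follow the docstrings and the values are invented:

import torch

batch, dim, seqlen, width = 2, 16, 32, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

# Full-sequence ("prefill") pass: output keeps the (batch, dim, seqlen) layout.
out = causal_conv1d_ref(x, weight, bias, activation="silu")
assert out.shape == (batch, dim, seqlen)

# Single-token decoding step against a rolling conv_state of width - 1 entries.
conv_state = x[:, :, -(width - 1):].clone()          # (batch, dim, width - 1)
x_step = torch.randn(batch, dim)                     # one new token per sequence
out_step = causal_conv1d_update_ref(x_step, conv_state, weight, bias, activation="silu")
assert out_step.shape == (batch, dim)                # conv_state was updated in place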
build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py DELETED
@@ -1,86 +0,0 @@
- import torch
- from torch import Tensor
-
- import triton
- import triton.language as tl
-
-
- @triton.jit
- def _causal_conv1d_varlen_states(
-     X,
-     CU_SEQLENS,
-     STATES,
-     state_len,
-     dim,
-     stride_x_seqlen, stride_x_dim,
-     stride_states_batch, stride_states_seqlen, stride_states_dim,
-     BLOCK_M: tl.constexpr,
-     BLOCK_N: tl.constexpr
- ):
-     batch_idx = tl.program_id(2)
-     STATES += batch_idx * stride_states_batch
-     end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
-     start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
-     rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
-     cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
-     x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
-                 mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
-                 other=0)
-     rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
-     tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
-              x,
-              mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
-
-
- def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
-     """
-     Forward pass only, does not support backward pass.
-     Parameters:
-         x: (total_tokens, dim)
-         cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
-         state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
-             If some of those elements belong to a different sequence, the value of the states will be zero.
-     Return:
-         states: (batch, dim, state_len)
-     """
-     _, dim = x.shape
-     batch = cu_seqlens.shape[0] - 1
-     cu_seqlens = cu_seqlens.contiguous()
-     states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
-     BLOCK_M = min(triton.next_power_of_2(state_len), 16)
-     BLOCK_N = min(triton.next_power_of_2(dim), 256)
-     grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
-     with torch.cuda.device(x.device.index):
-         _causal_conv1d_varlen_states[grid](
-             x,
-             cu_seqlens,
-             states,
-             state_len,
-             dim,
-             x.stride(0), x.stride(1),
-             states.stride(0), states.stride(2), states.stride(1),
-             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
-         )
-     return states
-
-
- def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
-     """
-     Forward pass only, does not support backward pass.
-     Parameters:
-         x: (total_tokens, dim)
-         cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
-         state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
-             If some of those elements belong to a different sequence, the value of the states will be zero.
-     Return:
-         states: (batch, dim, state_len)
-     """
-     _, dim = x.shape
-     batch = cu_seqlens.shape[0] - 1
-     cu_seqlens = cu_seqlens.contiguous()
-     states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
-     for i in range(batch):
-         end_idx = cu_seqlens[i + 1]
-         start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
-         states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
-     return states
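A hedged comparison sketch between the Triton-backed function and the reference above; it needs a CUDA device with Triton installed, and the shapes are made up. Since both paths only copy (or zero) values, an exact match is expected:

import torch

x = torch.randn(23, 64, device="cuda", dtype=torch.float16)       # (total_tokens, dim)
cu_seqlens = torch.tensor([0, 10, 13, 23], device="cuda", dtype=torch.int32)
state_len = 8

states = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
states_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
assert states.shape == (3, 64, state_len)
assert torch.equal(states, states_ref)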
build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py DELETED
build/torch28-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch28-cxx11-cu128-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:43ea19b486dc11d1eb780e7c1c4944ad27d27713ab41b8824b14add98c5eb645
- size 107168432
build/torch28-cxx11-cu128-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py DELETED
build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py DELETED
build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py DELETED
build/torch28-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch28-cxx11-cu129-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1a95ffd016cdfe6f401e0e495b0083a12395ffe82b3888c510b86f4a58dfe068
- size 115140584
build/torch28-cxx11-cu129-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _causal_conv1d_1b44a8e
- ops = torch.ops._causal_conv1d_1b44a8e
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_causal_conv1d_1b44a8e::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py DELETED
build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py DELETED
build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py DELETED
build/torch28-cxx11-cu129-x86_64-linux/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch29-cxx11-cu126-x86_64-linux/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
-
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_1b44a8e.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1652a695f38a12463ece0e84007e34575177c678c2432e97e1510064ea6b627a
- size 80684856