File size: 7,256 Bytes
be761d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
from typing import Optional, Tuple
import torch
import causal_conv1d_cuda
# Causal Conv1D Forward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Forward pass of the causal conv1d, dispatched to the CUDA kernel.

    Args:
        x: input tensor; indexed as 3-D by the kernel wrapper
           (assumed (batch, dim, seqlen) — TODO confirm against kernel docs).
        weight: conv filters; dim 0 must equal x.shape[-2] (see fake impl).
        bias: optional per-channel bias.
        seq_idx: optional sequence-index tensor forwarded to the kernel.
        activation: None, "silu", or "swish"; anything else raises.

    Returns:
        Output tensor from the CUDA kernel.

    Raises:
        NotImplementedError: if activation is not None/"silu"/"swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    # Copy x only when neither of its last two dims is unit-stride.
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()
    # The kernel expects contiguous bias/seq_idx when present.
    if bias is not None:
        bias = bias.contiguous()
    if seq_idx is not None:
        seq_idx = seq_idx.contiguous()
    # The kernel takes the activation as a boolean flag, not a string.
    apply_activation = activation in ("silu", "swish")
    # The two None slots presumably correspond to optional state tensors
    # of the kernel signature — TODO confirm against causal_conv1d_cuda.
    return causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, None, None, apply_activation
    )
# Register a fake forward pass for tracing
@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Meta/fake implementation: the output has exactly x's shape and dtype."""
    # Channel dim of x must line up with the number of conv filters.
    torch._check(x.shape[-2] == weight.shape[0])
    out = torch.empty_like(x)
    return out
# Causal Conv1D Backward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_bwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_bwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass of the causal conv1d, dispatched to the CUDA kernel.

    Args:
        x, weight, bias, seq_idx: tensors saved from the forward pass.
        dout: upstream gradient w.r.t. the forward output.
        activation: boolean flag, True when the forward used silu/swish.

    Returns:
        (dx, dweight, dbias); when bias is None, dbias is an empty (0,)
        placeholder tensor because custom ops must return tensors.
    """
    # Copy dout only when neither of its last two dims is unit-stride.
    if dout.stride(2) != 1 and dout.stride(1) != 1:
        dout = dout.contiguous()
    # Trailing kernel args (None, None, None, False) presumably cover
    # optional state tensors / flags — TODO confirm against the extension.
    dx, dweight, dbias, _ = causal_conv1d_cuda.causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, None, None, None, False, activation
    )
    if bias is None:
        # Placeholder so the op's return type stays a 3-tuple of tensors;
        # the autograd bridge converts it back to None.
        dbias = torch.empty((0,), device=dout.device)
    return dx, dweight, dbias
# Register a fake backward pass for tracing
@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Meta/fake implementation mirroring the real op's output metadata.

    The real op always returns three tensors: when bias is None it returns
    an empty (0,) placeholder for dbias, so the fake must do the same —
    returning None here would make the fake's signature diverge from the
    real kernel under tracing/compilation.
    """
    # Match the real op's placeholder convention for the missing bias grad.
    dbias = (
        torch.empty_like(bias)
        if bias is not None
        else torch.empty((0,), device=dout.device)
    )
    return (
        torch.empty_like(x),
        torch.empty_like(weight),
        dbias,
    )
# Setup context for autograd
def causal_conv1d_setup_context(ctx, inputs, output):
    """Save forward inputs on ctx so the backward bridge can use them."""
    x, weight, bias, seq_idx, activation = inputs
    # Store a plain bool: the backward op takes activation as a flag,
    # not as the original string.
    ctx.activation = activation in ("silu", "swish")
    ctx.save_for_backward(x, weight, bias, seq_idx)
# Bridge for backward pass in autograd
def causal_conv1d_bwd_bridge(ctx, dout):
    """Autograd backward hook: unpack ctx and invoke the backward custom op."""
    x, weight, bias, seq_idx = ctx.saved_tensors
    dx, dweight, dbias = causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, ctx.activation
    )
    if bias is None:
        # The op returns an empty placeholder for dbias; autograd wants None.
        dbias = None
    # Trailing Nones: seq_idx and activation receive no gradient.
    return dx, dweight, dbias, None, None
# Register custom autograd function
# Wire the backward bridge and context-setup hook into the forward op so
# autograd (and torch.compile) can differentiate through the custom kernel.
torch.library.register_autograd(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    causal_conv1d_bwd_bridge,
    setup_context=causal_conv1d_setup_context,
)
# Define a higher-level function to invoke the custom op
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
    """Public entry point: thin wrapper over the registered custom op."""
    return causal_conv1d_fwd(
        x, weight, bias=bias, seq_idx=seq_idx, activation=activation
    )
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_update",
    # The kernel updates conv_state in place (see docstring below), so the
    # mutation must be declared — with mutates_args=() torch.compile may
    # assume conv_state is unchanged and miscompile around this op.
    mutates_args=("conv_state",),
    device_types="cuda",
)
def causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Single-step (decoding) causal conv1d update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1;
        updated in place by the kernel.
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    out: (batch, dim) or (batch, dim, seqlen)

    Raises:
        NotImplementedError: if activation is not None/"silu"/"swish".
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    # The kernel takes the activation as a boolean flag, not a string.
    use_activation = activation in ("silu", "swish")
    # The kernel expects a 3-D input; add a seqlen dim for 2-D x and
    # remove it again on the way out.
    squeeze_back = x.dim() == 2
    if squeeze_back:
        x = x.unsqueeze(-1)
    out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, use_activation, cache_seqlens
    )
    if squeeze_back:
        out = out.squeeze(-1)
    return out
@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Meta/fake implementation: the output has exactly x's shape and dtype."""
    out = torch.empty_like(x)
    return out
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Public entry point for the decoding-step update custom op."""
    return causal_conv1d_update_fwd(
        x,
        conv_state,
        weight,
        bias=bias,
        activation=activation,
        cache_seqlens=cache_seqlens,
    )
# Test the implementation
# Smoke test comparing the custom op (eager + compiled) against the
# reference causal_conv1d implementation. Requires a CUDA device and the
# `causal_conv1d` package.
if __name__ == "__main__":
    from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref

    def _report(out, x, weight):
        # Print summary statistics for the output and the grads of x/weight.
        print(out.min(), out.max(), out.mean(), out.std())
        print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
        print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    torch.manual_seed(0)
    x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
    weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    # Set to torch.randn(32, device="cuda", requires_grad=True) to exercise bias.
    bias = None

    # Test the forward and backward pass
    print("Custom Implementation")
    out = causal_conv1d_fn(x, weight, bias, activation="silu")
    out.sum().backward()
    _report(out, x, weight)

    # Reset grads, then check the op still works under torch.compile.
    x.grad.zero_()
    weight.grad.zero_()
    compiled_conv1d = torch.compile(causal_conv1d_fn)
    print(compiled_conv1d)
    print("Compiled Implementation")
    out = compiled_conv1d(x, weight, bias, activation="silu")
    out.sum().backward()
    _report(out, x, weight)

    print("Reference Implementation")
    x.grad.zero_()
    weight.grad.zero_()
    out = causal_conv1d_fn_ref(x, weight, bias, activation="silu")
    out.sum().backward()
    _report(out, x, weight)
|