File size: 7,256 Bytes
be761d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from typing import Optional, Tuple
import torch
import causal_conv1d_cuda

# Causal Conv1D Forward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Forward pass of the causal depthwise conv1d CUDA kernel.

    x: input tensor; assumes (batch, dim, seqlen) with dim on axis -2 (the
        fake impl checks x.shape[-2] == weight.shape[0]) — TODO confirm.
    weight: (dim, width) depthwise filter — presumably; verify against kernel.
    bias: optional per-channel bias.
    seq_idx: optional sequence-index tensor forwarded to the kernel.
    activation: None, "silu", or "swish"; anything else raises
        NotImplementedError.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    # The kernel accepts either a seqlen-contiguous or dim-contiguous layout;
    # only fall back to a full copy when neither axis is unit-stride.
    if not (x.stride(2) == 1 or x.stride(1) == 1):
        x = x.contiguous()

    # Optional tensors must be contiguous before hitting the CUDA kernel.
    if bias is not None:
        bias = bias.contiguous()
    if seq_idx is not None:
        seq_idx = seq_idx.contiguous()

    # The kernel takes a plain bool flag rather than the activation name.
    apply_act = activation in ("silu", "swish")

    # The two None slots are unused optional kernel arguments (initial/final
    # states in the upstream API) — not exercised by this wrapper.
    return causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, None, None, apply_act
    )

# Register a fake forward pass for tracing
@causal_conv1d_fwd.register_fake
def _causal_conv1d_fwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Shape/dtype-only stand-in used by FakeTensor tracing (torch.compile).

    Validates that the channel axis of x matches the filter count, then
    returns an uninitialized tensor with x's shape, dtype and layout.
    """
    # Channel count of the input must equal the number of depthwise filters.
    torch._check(x.size(-2) == weight.size(0))
    return torch.empty_like(x)

# Causal Conv1D Backward Function
@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_bwd", 
    mutates_args=(),
    device_types="cuda",
)
def causal_conv1d_bwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass of the causal conv1d kernel.

    Returns (dx, dweight, dbias). Custom ops must return tensors, so when
    bias is None the dbias slot is filled with an empty (0,) placeholder
    (translated back to None by the autograd bridge).
    """
    # Same layout rule as the forward pass: copy only if neither the seqlen
    # nor the dim axis is unit-stride.
    if not (dout.stride(2) == 1 or dout.stride(1) == 1):
        dout = dout.contiguous()

    # Extra None/False slots mirror the unused initial-state arguments of the
    # kernel's signature; the 4th return value is likewise unused here.
    grads = causal_conv1d_cuda.causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, None, None, None, False, activation
    )
    dx, dweight, dbias = grads[0], grads[1], grads[2]

    # No bias → return an empty placeholder tensor instead of the kernel's
    # dbias output, keeping the declared Tensor return schema.
    if bias is None:
        dbias = torch.empty((0,), device=dout.device)

    return dx, dweight, dbias

# Register a fake backward pass for tracing
@causal_conv1d_bwd.register_fake
def _causal_conv1d_bwd_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    dout: torch.Tensor,
    seq_idx: Optional[torch.Tensor],
    activation: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for causal_conv1d_bwd under FakeTensor tracing.

    Must mirror the real op's output types exactly: the op is declared to
    return three Tensors, and the real implementation returns an empty (0,)
    tensor for dbias when bias is None — returning None here (as the previous
    version did) mismatches the registered schema during tracing.
    """
    dbias = (
        torch.empty_like(bias)
        if bias is not None
        else torch.empty((0,), device=dout.device)
    )
    return torch.empty_like(x), torch.empty_like(weight), dbias

# Setup context for autograd
def causal_conv1d_setup_context(ctx, inputs, output):
    """Stash everything the backward bridge needs on the autograd ctx.

    inputs unpacks as (x, weight, bias, seq_idx, activation); the activation
    name is folded down to the bool flag the backward kernel expects.
    """
    *saved, activation = inputs
    ctx.activation = activation in ("silu", "swish")
    ctx.save_for_backward(*saved)

# Bridge for backward pass in autograd
def causal_conv1d_bwd_bridge(ctx, dout):
    """Adapt the backward custom op to autograd's gradient protocol.

    Returns one gradient per forward input (x, weight, bias, seq_idx,
    activation); the last two are non-differentiable, hence None.
    """
    x, weight, bias, seq_idx = ctx.saved_tensors
    dx, dweight, dbias = causal_conv1d_bwd(
        x, weight, bias, dout, seq_idx, ctx.activation
    )

    # The op returns an empty placeholder for dbias when there is no bias;
    # autograd expects None in that case.
    if bias is None:
        dbias = None
    return dx, dweight, dbias, None, None

# Register custom autograd function
# Wires the forward custom op to its backward bridge so autograd (and
# torch.compile's AOTAutograd) can differentiate through it; setup_context
# saves the tensors and activation flag the backward pass needs.
torch.library.register_autograd(
    "mamba_causal_conv1d::causal_conv1d_fwd",
    causal_conv1d_bwd_bridge,
    setup_context=causal_conv1d_setup_context,
)

# Define a higher-level function to invoke the custom op
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
    """Public entry point: thin convenience wrapper over the registered op."""
    return causal_conv1d_fwd(
        x, weight, bias=bias, seq_idx=seq_idx, activation=activation
    )


@torch.library.custom_op(
    "mamba_causal_conv1d::causal_conv1d_update",
    # The kernel writes x into conv_state in place (see docstring below);
    # declaring the mutation is required so torch.compile does not treat the
    # op as functional and reorder/eliminate the state update.
    mutates_args=("conv_state",),
    device_types="cuda",
)
def causal_conv1d_update_fwd(
    x: torch.Tensor, 
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen)

    Raises NotImplementedError for any activation other than None, "silu",
    or "swish".
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # Separate bool flag (previously the str parameter was rebound in place).
    use_activation = activation in ["silu", "swish"]
    # The kernel expects a seqlen axis; add a length-1 one for 2-D inputs and
    # strip it from the output again below.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, use_activation, cache_seqlens
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out

@causal_conv1d_update_fwd.register_fake
def _causal_conv1d_update_fwd(
    x: torch.Tensor, 
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake impl for tracing: output mirrors x's shape, dtype and layout."""
    out = torch.empty_like(x)
    return out

def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Public entry point for the single-step (decode-time) conv update op."""
    return causal_conv1d_update_fwd(
        x,
        conv_state,
        weight,
        bias=bias,
        activation=activation,
        cache_seqlens=cache_seqlens,
    )

# Test the implementation: run the custom op eagerly, under torch.compile,
# and against the reference causal_conv1d package, printing summary stats
# of the outputs and gradients for a quick eyeball comparison.
if __name__ == "__main__":
    from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_ref

    torch.manual_seed(0)

    x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
    weight = torch.randn(32, 3, device="cuda", requires_grad=True)
    # Bias-less run; swap in torch.randn(32, device="cuda", requires_grad=True)
    # to exercise the bias path.
    bias = None

    def _print_stats(out, x, weight):
        # One line each for output, dx, and dweight statistics.
        print(out.min(), out.max(), out.mean(), out.std())
        print(x.grad.min(), x.grad.max(), x.grad.mean(), x.grad.std())
        print(weight.grad.min(), weight.grad.max(), weight.grad.mean(), weight.grad.std())

    # Test the forward and backward pass
    print("Custom Implementation")
    out = causal_conv1d_fn(x, weight, bias, activation="silu")
    out.sum().backward()
    _print_stats(out, x, weight)

    # Try compiling the function using torch.compile
    x.grad.zero_()
    weight.grad.zero_()
    compiled_conv1d = torch.compile(causal_conv1d_fn)
    print(compiled_conv1d)

    # Run the compiled function
    print("Compiled Implementation")
    out = compiled_conv1d(x, weight, bias, activation="silu")
    out.sum().backward()
    _print_stats(out, x, weight)

    print("Reference Implementation")
    x.grad.zero_()
    weight.grad.zero_()
    out = causal_conv1d_fn_ref(x, weight, bias, activation="silu")
    out.sum().backward()
    _print_stats(out, x, weight)