|
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
| from einops import rearrange, repeat
|
|
|
| try:
|
| from causal_conv1d import causal_conv1d_fn
|
| import causal_conv1d_cuda
|
| except ImportError:
|
| causal_conv1d_fn = None
|
| causal_conv1d_cuda = None
|
|
|
| import selective_scan_cuda
|
|
|
|
|
class SelectiveScanFn(torch.autograd.Function):
    """Autograd bridge to the fused ``selective_scan_cuda`` forward/backward kernels."""

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        """Run the fused scan and stash what ``backward`` needs on ``ctx``.

        Returns the scan output (the ``z``-gated output when ``z`` is given),
        or ``(output, last_state)`` when ``return_last_state`` is true.
        """
        # Inputs are made contiguous whenever their last axis is not unit-stride.
        u = u if u.stride(-1) == 1 else u.contiguous()
        delta = delta if delta.stride(-1) == 1 else delta.contiguous()
        if D is not None:
            D = D.contiguous()
        B = B if B.stride(-1) == 1 else B.contiguous()
        C = C if C.stride(-1) == 1 else C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()

        # 3-D B/C get a singleton group axis; remember it so the gradients
        # can be squeezed back to the caller's layout in backward().
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True

        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        # Last chunk of the scan intermediates, every second entry.
        last_state = x[:, :, -1, 1::2]

        if ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            result = rest[0]  # gated output produced by the kernel
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            result = out
        return (result, last_state) if return_last_state else result

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute input gradients with the fused backward kernel.

        Extra incoming gradients (``*args``, e.g. for ``last_state``) are ignored.
        """
        saved = ctx.saved_tensors
        if ctx.has_z:
            u, delta, A, B, C, D, z, delta_bias, x, out = saved
        else:
            u, delta, A, B, C, D, delta_bias, x = saved
            z = None
            out = None
        if dout.stride(-1) != 1:
            dout = dout.contiguous()

        # No pre-allocated dz is supplied (None); the trailing False is the
        # kernel's final boolean option, unused here.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False
        )
        dz = rest[0] if ctx.has_z else None
        if getattr(ctx, "squeeze_B", False):
            dB = dB.squeeze(1)
        if getattr(ctx, "squeeze_C", False):
            dC = dC.squeeze(1)

        grad_D = dD if D is not None else None
        grad_delta_bias = ddelta_bias if delta_bias is not None else None
        # One entry per forward() argument; the two flags get no gradient.
        return (du, ddelta, dA, dB, dC, grad_D, dz, grad_delta_bias, None, None)
|
|
|
|
|
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """Differentiable selective scan backed by :class:`SelectiveScanFn`.

    When ``return_last_state`` is True the result is ``(out, last_state)``,
    where ``last_state`` has shape (batch, dim, dstate). The gradient of the
    last state is not considered in the backward pass.
    """
    scan_args = (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return SelectiveScanFn.apply(*scan_args)
|
|
|
|
|
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """Pure-PyTorch reference implementation of the selective scan.

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)

    Returns ``out`` (cast back to the dtype of ``u``), or ``(out, last_state)``
    when ``return_last_state`` is true. A zero-length sequence yields an empty
    ``out`` and the all-zero initial state instead of raising.
    """

    def _pairs_to_complex(t):
        # Interpret the last axis as interleaved (real, imag) pairs.
        t = t.float()
        return torch.view_as_complex(t.reshape(*t.shape[:-1], t.shape[-1] // 2, 2))

    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = _pairs_to_complex(B)
        if is_variable_C:
            C = _pairs_to_complex(C)
    else:
        B = B.float()
        C = C.float()

    # Discretised transition and input terms for every time step.
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
    else:
        # Grouped B: each group is shared by dim // G consecutive channels.
        B = B.repeat_interleave(dim // B.shape[1], dim=1)
        deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = C.repeat_interleave(dim // C.shape[1], dim=1)

    x = A.new_zeros((batch, dim, dstate))
    ys = []
    for i in range(seqlen):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        elif C.dim() == 3:
            y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
        else:
            y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    # The state after the final step (the initial zeros if there were no steps).
    last_state = x

    # torch.stack rejects an empty list, so build the empty output explicitly.
    y = torch.stack(ys, dim=2) if ys else u.new_zeros((batch, dim, 0))
    out = y if D is None else y + u * D.unsqueeze(-1)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
|
|
|
|
|
class MambaInnerFn(torch.autograd.Function):
    """Fused Mamba inner block with a hand-written backward pass.

    Forward: split ``xz`` into ``x`` and the gate ``z``, run the causal conv1d
    (with SiLU) on ``x``, project to ``delta`` / ``B`` / ``C``, run the
    selective scan gated by ``z`` and apply the output projection.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)

        ``B`` / ``C`` left as None are taken from the ``x_proj_weight``
        projection (input-dependent); otherwise the given tensors are used.
        ``checkpoint_lvl`` 1 drops ``conv1d_out`` and ``delta`` after the
        forward pass and recomputes them in backward; 0 keeps them.
        Returns the output-projected result laid out as (b, l, d_out).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        # A complex state occupies two real columns of the projection.
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None

        # Must be the exact call backward() uses to recompute conv1d_out
        # (activation enabled); running the forward conv without the
        # activation would make the gradients describe a different function.
        # NOTE(review): positional layout assumed to be (x, weight, bias,
        # seq_idx, initial_states, final_states_out, silu_activation) -
        # confirm against the installed causal-conv1d version.
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )

        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:
            # Input-dependent B: the columns right after the delta block.
            B = x_dbl[:, delta_rank:delta_rank + d_state]
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:
            # Input-dependent C: the trailing d_state columns.
            C = x_dbl[:, -d_state:]
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:
            # Recomputed in backward() to save memory.
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """Back-propagate ``dout`` (b, l, e) through the fused block.

        Returns one gradient per ``forward`` argument (None for the
        non-tensor flags ``delta_softplus`` and ``checkpoint_lvl``).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # Rebuild the activations that forward() chose not to keep.
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)

        # dx and dz are views into dxz, so kernels that fill them in place
        # populate the returned gradient directly.
        dxz = torch.empty_like(xz)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        # The trailing True is presumably what makes the kernel hand back
        # out_z (needed for the output-projection weight gradient) - confirm.
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is (e, b*l) at this point: the bias gradient keeps the feature
        # axis and sums only over the flattened batch/time axis.
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB
            dB = None  # B was derived from x_dbl, not passed in by the caller
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC
            dC = None  # C was derived from x_dbl, not passed in by the caller
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into the scan's gradient in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])

        # dx is passed as the output buffer, so the result lands in dxz.
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        # Two trailing Nones: delta_softplus and checkpoint_lvl. Autograd
        # accepts extra trailing Nones, so callers that omit checkpoint_lvl
        # are unaffected while callers that pass it no longer fail.
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None, None)
|
|
|
class MambaInnerFnNoOutProj(torch.autograd.Function):
    """Fused Mamba inner block without the output projection.

    Forward: split ``xz`` into ``x`` and the gate ``z``, run the causal conv1d
    (with SiLU) on ``x``, project to ``delta`` / ``B`` / ``C`` and run the
    selective scan gated by ``z``. The gated scan output is returned as is.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)

        ``B`` / ``C`` left as None are taken from the ``x_proj_weight``
        projection (input-dependent); otherwise the given tensors are used.
        ``checkpoint_lvl`` 1 drops ``conv1d_out`` and ``delta`` after the
        forward pass and recomputes them in backward; 0 keeps them.
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        # A complex state occupies two real columns of the projection.
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None

        # Must be the exact call backward() uses to recompute conv1d_out
        # (activation enabled); running the forward conv without the
        # activation would make the gradients describe a different function.
        # NOTE(review): positional layout assumed to be (x, weight, bias,
        # seq_idx, initial_states, final_states_out, silu_activation) -
        # confirm against the installed causal-conv1d version.
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )

        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:
            # Input-dependent B: the columns right after the delta block.
            B = x_dbl[:, delta_rank:delta_rank + d_state]
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:
            # Input-dependent C: the trailing d_state columns.
            C = x_dbl[:, -d_state:]
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:
            # Recomputed in backward() to save memory.
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return out_z

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """Back-propagate ``dout`` (gradient of the gated scan output).

        Returns one gradient per ``forward`` argument (None for the
        non-tensor flags ``delta_softplus`` and ``checkpoint_lvl``).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # Rebuild the activations that forward() chose not to keep.
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)

        # dx and dz are views into dxz, so kernels that fill them in place
        # populate the returned gradient directly.
        dxz = torch.empty_like(xz)
        dx, dz = dxz.chunk(2, dim=1)

        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True
        )
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB
            dB = None  # B was derived from x_dbl, not passed in by the caller
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC
            dC = None  # C was derived from x_dbl, not passed in by the caller
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into the scan's gradient in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])

        # dx is passed as the output buffer, so the result lands in dxz.
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        # Two trailing Nones: delta_softplus and checkpoint_lvl. Autograd
        # accepts extra trailing Nones, so callers that omit checkpoint_lvl
        # are unaffected while callers that pass it no longer fail.
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None, None)
|
|
|
class BiMambaInnerFn(torch.autograd.Function):
    """Bidirectional fused Mamba inner block with a hand-written backward pass.

    The conv1d / projection stage is shared. The selective scan is run once
    on the sequence as given (state matrix ``A``) and once on the
    time-reversed sequence (state matrix ``A_b``); the reversed result is
    flipped back, summed with the forward result and output-projected.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
        xz: (batch, dim, seqlen)

        ``B`` / ``C`` left as None are taken from the ``x_proj_weight``
        projection (input-dependent); otherwise the given tensors are used.
        ``checkpoint_lvl`` 1 drops ``conv1d_out`` and ``delta`` after the
        forward pass and recomputes them in backward; 0 keeps them.
        Returns the output-projected result laid out as (b, l, d_out).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        # A complex state occupies two real columns of the projection.
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        # Same call as the recomputation in backward().
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )

        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:
            # Input-dependent B: the columns right after the delta block.
            B = x_dbl[:, delta_rank:delta_rank + d_state]
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:
            # Input-dependent C: the trailing d_state columns.
            C = x_dbl[:, -d_state:]
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        assert not A_b.is_complex(), "A should not be complex!!"
        # Reverse direction: every time-indexed input is flipped along time.
        out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
            z.flip([-1]), delta_bias, delta_softplus,
        )

        # Bring the reversed result back to the original time order.
        out_z = out_z_f + out_z_b.flip([-1])

        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:
            # Recomputed in backward() to save memory.
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, A_b, B, C, D, delta_bias,
                              scan_intermediates_f, scan_intermediates_b, out_f, out_b)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """Back-propagate ``dout`` (b, l, e) through both scan directions.

        Returns one gradient per ``forward`` argument (None for the
        non-tensor flags ``delta_softplus`` and ``checkpoint_lvl``).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, A_b, B, C, D, delta_bias,
         scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # Rebuild the activations that forward() chose not to keep.
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)

        # dx and dz are views into dxz; anything meant to reach the returned
        # gradient must be written into them, not rebound to a new tensor.
        dxz = torch.empty_like(xz)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz_f, out_z_f = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
            ctx.delta_softplus,
            True
        )

        dz_b = torch.empty_like(dz)
        (dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b,
         dz_b, out_z_b) = selective_scan_cuda.bwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
            z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
            ctx.delta_softplus,
            True
        )

        # Merge both directions; reverse-direction results are flipped back
        # to the original time order first.
        dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
        ddelta = ddelta + ddelta_f_b.flip([-1])
        dB = dB + dB_f_b.flip([-1])
        dC = dC + dC_f_b.flip([-1])
        # NOTE(review): the kernel presumably returns None for dD /
        # ddelta_bias when D / delta_bias are absent - skip the sum then.
        dD = dD + dD_b if D is not None else None
        ddelta_bias = ddelta_bias + ddelta_bias_b if delta_bias is not None else None
        # Write the combined gate gradient INTO the dxz view; a plain
        # `dz = dz + ...` would leave dxz with the forward half only.
        dz.copy_(dz_f + dz_b.flip([-1]))
        out_z = out_z_f + out_z_b.flip([-1])

        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is (e, b*l) at this point: the bias gradient keeps the feature
        # axis and sums only over the flattened batch/time axis.
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB
            dB = None  # B was derived from x_dbl, not passed in by the caller
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC
            dC = None  # C was derived from x_dbl, not passed in by the caller
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into the scan's gradient in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])

        # Same 10-argument call as the other fused blocks in this file, to
        # match the 7-argument causal_conv1d_fwd used above. dx is passed as
        # the output buffer, so the result lands in dxz.
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        # Two trailing Nones: delta_softplus and checkpoint_lvl. Autograd
        # accepts extra trailing Nones, so callers that omit checkpoint_lvl
        # are unaffected while callers that pass it no longer fail.
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dA_b, dB, dC, dD,
                ddelta_bias,
                dB_proj_bias, dC_proj_bias, None, None)
|
|
|
|
|
def bimamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for :class:`BiMambaInnerFn`.

    Arguments are forwarded positionally in declaration order;
    ``checkpoint_lvl`` is left at the autograd function's default.
    """
    fused_args = (
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus,
    )
    return BiMambaInnerFn.apply(*fused_args)
|
|
|
|
|
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for :class:`MambaInnerFn`.

    Arguments are forwarded positionally in declaration order;
    ``checkpoint_lvl`` is left at the autograd function's default.
    """
    fused_args = (
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus,
    )
    return MambaInnerFn.apply(*fused_args)
|
|
|
def mamba_inner_fn_no_out_proj(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for :class:`MambaInnerFnNoOutProj`.

    Arguments are forwarded positionally in declaration order;
    ``checkpoint_lvl`` is left at the autograd function's default.
    """
    fused_args = (
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus,
    )
    return MambaInnerFnNoOutProj.apply(*fused_args)
|
|
|
|
|
def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Unfused reference for the Mamba inner block.

    Runs the causal conv1d with SiLU on the ``x`` half of ``xz``
    (batch, dim, seqlen on input), projects to ``delta`` / ``B`` / ``C``,
    runs ``selective_scan_fn`` gated by the ``z`` half and applies the output
    projection. ``B`` / ``C`` left as None are taken from the projection.
    ``delta_softplus`` is forwarded to the scan.
    """
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # A complex state occupies two real columns of the projection.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")

    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:
        # Input-dependent B: the columns right after the delta block.
        B = x_dbl[:, delta_rank:delta_rank + d_state]
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:
        # Input-dependent C: the trailing d_state columns.
        C = x_dbl[:, -d_state:]
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    # Honour the caller's delta_softplus instead of forcing it on, so the
    # reference tracks the fused path for either setting.
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias,
                          delta_softplus=delta_softplus)
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
|
|
def bimamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Unfused reference for the bidirectional Mamba inner block.

    Shares the conv1d / projection stage, then runs ``selective_scan_fn`` on
    the sequence as given (with ``A``) and on the time-reversed sequence
    (with ``A_b``); the reversed result is flipped back, summed with the
    forward one and output-projected. ``B`` / ``C`` left as None are taken
    from the projection. ``delta_softplus`` is forwarded to both scans.
    """
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # A complex state occupies two real columns of the projection.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    # `activation` is passed by keyword: positionally, the fourth argument is
    # not the activation in every causal-conv1d release.
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")

    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:
        # Input-dependent B: the columns right after the delta block.
        B = x_dbl[:, delta_rank:delta_rank + d_state]
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:
        # Input-dependent C: the trailing d_state columns.
        C = x_dbl[:, -d_state:]
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    # Honour the caller's delta_softplus instead of forcing it on.
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias,
                          delta_softplus=delta_softplus)
    y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D,
                            z.flip([-1]), delta_bias, delta_softplus=delta_softplus)
    # Bring the reversed scan back to the original time order before summing.
    y = y + y_b.flip([-1])
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)