import math

import torch
import torch.nn.functional as F

"""
An implementation of the parallel scan operation in PyTorch (Blelloch version).
Please see docs/pscan.ipynb for a detailed explanation of what happens here.
"""

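# (Added note, not in the original file.) The scan below evaluates the linear
# recurrence H[t] = A[t] * H[t-1] + X[t] (with H[0] = X[0]) in parallel. This
# works because adjacent steps combine associatively:
#     (a1, x1) o (a2, x2) = (a1 * a2, a2 * x1 + x2)
# and that combine rule is exactly what the paired add_/mul_ updates in
# PScan.pscan apply to slices of A and X.
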
def npo2(len):
    """
    Returns the smallest power of 2 greater than or equal to len
    """

    return 2 ** math.ceil(math.log2(len))

def pad_npo2(X):
    """
    Pads the input along the length dim up to the next power of 2

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N)
    """

    len_npo2 = npo2(X.size(1))
    # F.pad consumes the pad tuple from the last dim backwards:
    # (N, N, D, D, L, L) -> only the length dim is padded, at its end
    pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
    return F.pad(X, pad_tuple, "constant", 0)

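# (Added illustration, not part of the original file; the helper name
# _sequential_scan_ref is introduced here for testing only.)
# A naive sequential reference for the recurrence computed by PScan below.
def _sequential_scan_ref(A, X):
    """
    A, X : (B, L, D, N)
    Returns H : (B, L, D, N) with H[t] = A[t] * H[t-1] + X[t], H[-1] = 0
    """
    hs = []
    h = torch.zeros_like(X[:, 0])  # (B, D, N)
    for t in range(X.size(1)):
        h = A[:, t] * h + X[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)
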
class PScan(torch.autograd.Function):
    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel (Blelloch) scan:
        # X is overwritten with H, where H[t] = A[t] * H[t-1] + X[t] and H[0] = X[0],
        # computed in about 2*log2(L) sequential steps instead of L.

        # only supports L that is a power of two

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep: combine adjacent pairs, keep the right element of each pair
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # only 4, 2 or 1 nodes left at the top of the tree: handled explicitly
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        elif Xa.size(2) == 2:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            return

        # down sweep: propagate the partial results back down the tree
        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # same combine rule as pscan, but accumulating from the end of the
        # sequence toward the start; used in the backward pass to propagate
        # gradients backward through the recurrence.

        # only supports L that is a power of two

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep, mirrored: combine adjacent pairs, keep the left element of each pair
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # only 4, 2 or 1 nodes left at the top of the tree: handled explicitly
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
        elif Xa.size(2) == 2:
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            return

        # down sweep, mirrored
        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, prefer sequence lengths that are powers of two.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad tensors (this also clones them)
            A = pad_npo2(A_in) # (B, npo2(L), D, N)
            X = pad_npo2(X_in) # (B, npo2(L), D, N)

        # prepare tensors
        A = A.transpose(2, 1) # (B, D, npo2(L), N)
        X = X.transpose(2, 1) # (B, D, npo2(L), N)

        # parallel scan (modifies X in-place)
        PScan.pscan(A, X)

        ctx.save_for_backward(A_in, X)

        # cut off the padding, if any
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Propagates the gradient from the output to the inputs. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            grad_output = grad_output_in.clone()
        else:
            grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in) # (B, npo2(L), D, N)

        # prepare tensors
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
        # shift A one step to the left along L, padding with a trailing zero
        A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N)

        # reverse parallel scan (modifies grad_output in-place)
        PScan.pscan_rev(A, grad_output)

        # gradA[t] = H[t-1] * (reverse-scanned grad_output)[t], with gradA[0] = 0
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]

pscan = PScan.apply
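
# ---------------------------------------------------------------------------
# (Added usage / self-check sketch, not part of the original file.)
# Compares pscan against the sequential reference defined above, and the
# custom backward against plain autograd through that reference.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    B, L, D, N = 2, 6, 3, 4  # L = 6 is deliberately not a power of two

    A = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)
    X = torch.randn(B, L, D, N, dtype=torch.float64, requires_grad=True)

    H = pscan(A, X)                     # parallel scan with the custom autograd Function
    H_ref = _sequential_scan_ref(A, X)  # sequential reference

    print("forward max abs error:", (H - H_ref).abs().max().item())

    # gradients through the custom backward
    H.sum().backward()
    gradA, gradX = A.grad.clone(), X.grad.clone()

    # gradients through plain autograd on the sequential reference
    A.grad = None
    X.grad = None
    H_ref.sum().backward()

    print("gradA max abs error:", (gradA - A.grad).abs().max().item())
    print("gradX max abs error:", (gradX - X.grad).abs().max().item())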