Abdullah-Nazhat committed on
Commit
be02ece
·
verified ·
1 Parent(s): e787de7

Update pscan.py

Browse files
Files changed (1) hide show
  1. pscan.py +29 -78
pscan.py CHANGED
@@ -3,30 +3,15 @@ import math
3
  import torch
4
  import torch.nn.functional as F
5
 
6
- """
7
 
8
- An implementation of the parallel scan operation in PyTorch (Blelloch version).
9
- Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10
-
11
- """
12
 
13
def npo2(len):
    """Return the next power of 2 at or above *len*.

    Uses integer bit arithmetic rather than ``2 ** math.ceil(math.log2(len))``
    so the result stays exact for arbitrarily large ints (float ``log2``
    loses precision past 2**53).

    Args:
        len: positive sequence length.

    Returns:
        The smallest power of two that is >= len.

    Raises:
        ValueError: if len < 1 (mirrors math.log2's domain error on the
            same inputs).
    """
    if len < 1:
        raise ValueError("len must be a positive integer")
    # For len >= 1: (len - 1).bit_length() is ceil(log2(len)) computed exactly.
    return 1 << (len - 1).bit_length()
19
 
20
  def pad_npo2(X):
21
- """
22
- Pads input length dim to the next power of 2
23
-
24
- Args:
25
- X : (B, L, D, N)
26
-
27
- Returns:
28
- Y : (B, npo2(L), D, N)
29
- """
30
 
31
  len_npo2 = npo2(X.size(1))
32
  pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
@@ -35,20 +20,12 @@ def pad_npo2(X):
35
  class PScan(torch.autograd.Function):
36
  @staticmethod
37
  def pscan(A, X):
38
- # A : (B, D, L, N)
39
- # X : (B, D, L, N)
40
-
41
- # modifies X in place by doing a parallel scan.
42
- # more formally, X will be populated by these values :
43
- # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
44
- # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
45
-
46
- # only supports L that is a power of two (mainly for a clearer code)
47
 
48
  B, D, L, _ = A.size()
49
  num_steps = int(math.log2(L))
50
 
51
- # up sweep (last 2 steps unfolded)
52
  Aa = A
53
  Xa = X
54
  for _ in range(num_steps-2):
@@ -62,7 +39,7 @@ class PScan(torch.autograd.Function):
62
  Aa = Aa[:, :, :, 1]
63
  Xa = Xa[:, :, :, 1]
64
 
65
- # we have only 4, 2 or 1 nodes left
66
  if Xa.size(2) == 4:
67
  Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
68
  Aa[:, :, 1].mul_(Aa[:, :, 0])
@@ -74,7 +51,7 @@ class PScan(torch.autograd.Function):
74
  else:
75
  return
76
 
77
- # down sweep (first 2 steps unfolded)
78
  Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
79
  Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
80
  Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
@@ -93,19 +70,12 @@ class PScan(torch.autograd.Function):
93
 
94
  @staticmethod
95
  def pscan_rev(A, X):
96
- # A : (B, D, L, N)
97
- # X : (B, D, L, N)
98
-
99
- # the same function as above, but in reverse
100
- # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
101
- # it is used in the backward pass
102
-
103
- # only supports L that is a power of two (mainly for a clearer code)
104
 
105
  B, D, L, _ = A.size()
106
  num_steps = int(math.log2(L))
107
 
108
- # up sweep (last 2 steps unfolded)
109
  Aa = A
110
  Xa = X
111
  for _ in range(num_steps-2):
@@ -119,7 +89,7 @@ class PScan(torch.autograd.Function):
119
  Aa = Aa[:, :, :, 0]
120
  Xa = Xa[:, :, :, 0]
121
 
122
- # we have only 4, 2 or 1 nodes left
123
  if Xa.size(2) == 4:
124
  Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
125
  Aa[:, :, 2].mul_(Aa[:, :, 3])
@@ -131,7 +101,7 @@ class PScan(torch.autograd.Function):
131
  else:
132
  return
133
 
134
- # down sweep (first 2 steps unfolded)
135
  Aa = A[:, :, 0:L:2**(num_steps-2)]
136
  Xa = X[:, :, 0:L:2**(num_steps-2)]
137
  Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
@@ -150,72 +120,53 @@ class PScan(torch.autograd.Function):
150
 
151
  @staticmethod
152
  def forward(ctx, A_in, X_in):
153
- """
154
- Applies the parallel scan operation, as defined above. Returns a new tensor.
155
- If you can, privilege sequence lengths that are powers of two.
156
-
157
- Args:
158
- A_in : (B, L, D, N)
159
- X_in : (B, L, D, N)
160
-
161
- Returns:
162
- H : (B, L, D, N)
163
- """
164
 
165
  L = X_in.size(1)
166
 
167
- # cloning is requiered because of the in-place ops
168
  if L == npo2(L):
169
  A = A_in.clone()
170
  X = X_in.clone()
171
  else:
172
- # pad tensors (and clone btw)
173
- A = pad_npo2(A_in) # (B, npo2(L), D, N)
174
- X = pad_npo2(X_in) # (B, npo2(L), D, N)
175
 
176
- # prepare tensors
177
- A = A.transpose(2, 1) # (B, D, npo2(L), N)
178
- X = X.transpose(2, 1) # (B, D, npo2(L), N)
179
 
180
- # parallel scan (modifies X in-place)
181
  PScan.pscan(A, X)
182
 
183
  ctx.save_for_backward(A_in, X)
184
 
185
- # slice [:, :L] (cut if there was padding)
186
  return X.transpose(2, 1)[:, :L]
187
 
188
  @staticmethod
189
  def backward(ctx, grad_output_in):
190
- """
191
- Flows the gradient from the output to the input. Returns two new tensors.
192
-
193
- Args:
194
- ctx : A_in : (B, L, D, N), X : (B, D, L, N)
195
- grad_output_in : (B, L, D, N)
196
-
197
- Returns:
198
- gradA : (B, L, D, N), gradX : (B, L, D, N)
199
- """
200
 
201
  A_in, X = ctx.saved_tensors
202
 
203
  L = grad_output_in.size(1)
204
 
205
- # cloning is required because of the in-place ops
206
  if L == npo2(L):
207
  grad_output = grad_output_in.clone()
208
- # the next padding will clone A_in
209
  else:
210
- grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
211
- A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
212
 
213
- # prepare tensors
214
  grad_output = grad_output.transpose(2, 1)
215
- A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
216
- A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
217
 
218
- # reverse parallel scan (modifies grad_output in-place)
219
  PScan.pscan_rev(A, grad_output)
220
 
221
  Q = torch.zeros_like(X)
 
3
  import torch
4
  import torch.nn.functional as F
5
 
 
6
 
 
 
 
 
7
 
8
def npo2(len):
    """Return the smallest power of two that is >= len."""
    exponent = math.ceil(math.log2(len))
    return 2 ** exponent
12
 
13
  def pad_npo2(X):
14
+
 
 
 
 
 
 
 
 
15
 
16
  len_npo2 = npo2(X.size(1))
17
  pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
 
20
  class PScan(torch.autograd.Function):
21
  @staticmethod
22
  def pscan(A, X):
23
+
 
 
 
 
 
 
 
 
24
 
25
  B, D, L, _ = A.size()
26
  num_steps = int(math.log2(L))
27
 
28
+
29
  Aa = A
30
  Xa = X
31
  for _ in range(num_steps-2):
 
39
  Aa = Aa[:, :, :, 1]
40
  Xa = Xa[:, :, :, 1]
41
 
42
+
43
  if Xa.size(2) == 4:
44
  Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
45
  Aa[:, :, 1].mul_(Aa[:, :, 0])
 
51
  else:
52
  return
53
 
54
+
55
  Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
56
  Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
57
  Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
 
70
 
71
  @staticmethod
72
  def pscan_rev(A, X):
73
+
 
 
 
 
 
 
 
74
 
75
  B, D, L, _ = A.size()
76
  num_steps = int(math.log2(L))
77
 
78
+
79
  Aa = A
80
  Xa = X
81
  for _ in range(num_steps-2):
 
89
  Aa = Aa[:, :, :, 0]
90
  Xa = Xa[:, :, :, 0]
91
 
92
+
93
  if Xa.size(2) == 4:
94
  Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
95
  Aa[:, :, 2].mul_(Aa[:, :, 3])
 
101
  else:
102
  return
103
 
104
+
105
  Aa = A[:, :, 0:L:2**(num_steps-2)]
106
  Xa = X[:, :, 0:L:2**(num_steps-2)]
107
  Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
 
120
 
121
  @staticmethod
122
  def forward(ctx, A_in, X_in):
123
+
 
 
 
 
 
 
 
 
 
 
124
 
125
  L = X_in.size(1)
126
 
127
+
128
  if L == npo2(L):
129
  A = A_in.clone()
130
  X = X_in.clone()
131
  else:
132
+
133
+ A = pad_npo2(A_in)
134
+ X = pad_npo2(X_in)
135
 
136
+
137
+ A = A.transpose(2, 1)
138
+ X = X.transpose(2, 1)
139
 
140
+
141
  PScan.pscan(A, X)
142
 
143
  ctx.save_for_backward(A_in, X)
144
 
145
+
146
  return X.transpose(2, 1)[:, :L]
147
 
148
  @staticmethod
149
  def backward(ctx, grad_output_in):
150
+
 
 
 
 
 
 
 
 
 
151
 
152
  A_in, X = ctx.saved_tensors
153
 
154
  L = grad_output_in.size(1)
155
 
156
+
157
  if L == npo2(L):
158
  grad_output = grad_output_in.clone()
159
+
160
  else:
161
+ grad_output = pad_npo2(grad_output_in)
162
+ A_in = pad_npo2(A_in)
163
 
164
+
165
  grad_output = grad_output.transpose(2, 1)
166
+ A_in = A_in.transpose(2, 1)
167
+ A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1))
168
 
169
+
170
  PScan.pscan_rev(A, grad_output)
171
 
172
  Q = torch.zeros_like(X)