Ubuntu committed on
Commit a3d4355 · 1 Parent(s): 506a7e5

feat: torch 2.5.1+cu124

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+/build/temp*
README.md ADDED
@@ -0,0 +1,20 @@
+This CUDA extension implements fused dropout + residual + LayerNorm, building on
+Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
+Major changes:
+- Add dropout and residual.
+- Make it work for both pre-norm and post-norm architectures.
+- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
+- Implement RMSNorm as an option.
+- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
+
+If you want to use it for dimensions larger than 8k, please file an issue.
+
+This extension has only been tested on A100s.
+
+```sh
+cd csrc/layer_norm && pip install .
+```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
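A minimal usage sketch of the API defined in `api.py` below (illustrative only: it assumes the extension built successfully as the `dropout_layer_norm` module that `api.py` imports, and the shapes, dtypes, and dropout probability are arbitrary):

```python
# Hypothetical usage sketch -- names mirror the functions/classes in api.py;
# requires a CUDA device and the compiled dropout_layer_norm extension.
import torch
from api import dropout_add_layer_norm, DropoutAddLayerNorm

batch, seqlen, hidden = 2, 128, 1024   # hidden must be divisible by 8 and <= 8192
x0 = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

# Post-norm style: returns only z = LayerNorm(dropout(x0) + residual).
z = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5)

# Pre-norm style: also returns x = dropout(x0) + residual, the running residual stream.
z, x = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5, prenorm=True)

# Module form, with dropout disabled in eval mode.
layer = DropoutAddLayerNorm(hidden, prenorm=False, p=0.1, eps=1e-5,
                            device="cuda", dtype=torch.float16).eval()
out = layer(x0, residual)
```

The `prenorm=True` path is what a pre-norm transformer block would use: `z` feeds the next sub-layer while `x` is carried forward as the residual stream.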
api.py ADDED
@@ -0,0 +1,800 @@
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
3
+
4
+ import dropout_layer_norm
5
+ import torch
6
+ from torch.nn import init
7
+
8
+
9
+ def maybe_align(x, alignment_in_bytes=16):
10
+ """Assume that x already has last dim divisible by alignment_in_bytes"""
11
+ # TD [2023-07-04] I'm not 100% sure that clone will align the memory
12
+ # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
13
+ return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
14
+
15
+
16
+ def _dropout_add_layer_norm_forward(
17
+ x0,
18
+ residual,
19
+ gamma,
20
+ beta,
21
+ rowscale,
22
+ colscale,
23
+ dropout_p,
24
+ epsilon,
25
+ residual_in_fp32=False,
26
+ is_rms_norm=False,
27
+ ):
28
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
29
+ hidden_size = gamma.numel()
30
+ x0mat = x0.view((-1, hidden_size))
31
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
32
+ rowscale = rowscale.view(-1) if rowscale is not None else None
33
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
34
+ x0mat,
35
+ residualmat,
36
+ gamma,
37
+ beta,
38
+ rowscale,
39
+ colscale,
40
+ None,
41
+ None,
42
+ dropout_p,
43
+ epsilon,
44
+ 1.0,
45
+ 0,
46
+ None,
47
+ residual_in_fp32,
48
+ is_rms_norm,
49
+ )
50
+ # dmask is None if dropout_p == 0.0
51
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
52
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
53
+
54
+
55
+ def _dropout_add_layer_norm_backward(
56
+ dz,
57
+ dx,
58
+ x,
59
+ x0,
60
+ dmask,
61
+ mu,
62
+ rsigma,
63
+ gamma,
64
+ rowscale,
65
+ colscale,
66
+ dropout_p,
67
+ has_residual,
68
+ is_rms_norm=False,
69
+ ):
70
+ """Assume that arguments are contiguous and aligned to 16 bytes
71
+ dx == None means that it was a post-norm architecture
72
+ (x = drop(x0) + residual was not returned in the fwd).
73
+ x0 must not be None if we have colscale.
74
+ """
75
+ hidden_size = gamma.numel()
76
+ xmat = x.view((-1, hidden_size))
77
+ dzmat = dz.view(xmat.shape)
78
+ dxmat = dx.view(xmat.shape) if dx is not None else None
79
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
80
+ rowscale = rowscale.view(-1) if rowscale is not None else None
81
+ if colscale is not None:
82
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
83
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
84
+ dzmat,
85
+ dxmat,
86
+ xmat,
87
+ x0mat,
88
+ dmask,
89
+ mu,
90
+ rsigma,
91
+ gamma,
92
+ rowscale,
93
+ colscale,
94
+ None,
95
+ None,
96
+ dropout_p,
97
+ 1.0,
98
+ 0,
99
+ has_residual,
100
+ is_rms_norm,
101
+ )
102
+ # dresidualmat is None if not has_residual
103
+ if colscale is None:
104
+ return dx0mat, dresidualmat, dgamma, dbeta
105
+ else:
106
+ dcolscale = rest[0]
107
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
108
+
109
+
110
+ def _dropout_add_layer_norm_subset_forward(
111
+ x0,
112
+ residual,
113
+ gamma,
114
+ beta,
115
+ colscale,
116
+ x0_subset,
117
+ out_subset,
118
+ dropout_p,
119
+ epsilon,
120
+ rowscale_const,
121
+ out_numrows,
122
+ residual_in_fp32=False,
123
+ is_rms_norm=False,
124
+ ):
125
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
126
+ hidden_size = gamma.numel()
127
+ x0mat = x0.view((-1, hidden_size))
128
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
129
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
130
+ out_subset = out_subset.view(-1) if out_subset is not None else None
131
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
132
+ x0mat,
133
+ residualmat,
134
+ gamma,
135
+ beta,
136
+ None,
137
+ colscale,
138
+ x0_subset,
139
+ out_subset,
140
+ dropout_p,
141
+ epsilon,
142
+ rowscale_const,
143
+ out_numrows,
144
+ None,
145
+ residual_in_fp32,
146
+ is_rms_norm,
147
+ )
148
+ # dmask is None if dropout_p == 0.0
149
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
150
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
151
+
152
+
153
+ def _dropout_add_layer_norm_subset_backward(
154
+ dz,
155
+ dx,
156
+ x,
157
+ x0,
158
+ dmask,
159
+ mu,
160
+ rsigma,
161
+ gamma,
162
+ colscale,
163
+ x0_subset,
164
+ out_subset,
165
+ dropout_p,
166
+ rowscale_const,
167
+ x0_numrows,
168
+ has_residual,
169
+ is_rms_norm=False,
170
+ ):
171
+ """Assume that arguments are contiguous and aligned to 16 bytes
172
+ dx == None means that it was a post-norm architecture
173
+ (x = drop(x0) + residual was not returned in the fwd).
174
+ x0 must not be None if we have colscale.
175
+ """
176
+ hidden_size = gamma.numel()
177
+ xmat = x.view((-1, hidden_size))
178
+ dzmat = dz.view(-1, hidden_size)
179
+ dxmat = dx.view(xmat.shape) if dx is not None else None
180
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
181
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
182
+ out_subset = out_subset.view(-1) if out_subset is not None else None
183
+ if colscale is not None:
184
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
185
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
186
+ dzmat,
187
+ dxmat,
188
+ xmat,
189
+ x0mat,
190
+ dmask,
191
+ mu,
192
+ rsigma,
193
+ gamma,
194
+ None,
195
+ colscale,
196
+ x0_subset,
197
+ out_subset,
198
+ dropout_p,
199
+ rowscale_const,
200
+ x0_numrows,
201
+ has_residual,
202
+ is_rms_norm,
203
+ )
204
+ # dresidualmat is None if not has_residual
205
+ if colscale is None:
206
+ return dx0mat, dresidualmat, dgamma, dbeta
207
+ else:
208
+ dcolscale = rest[0]
209
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
210
+
211
+
212
+ def _dropout_add_layer_norm_parallel_residual_forward(
213
+ x0,
214
+ x1,
215
+ residual,
216
+ gamma0,
217
+ beta0,
218
+ gamma1,
219
+ beta1,
220
+ dropout_p,
221
+ epsilon,
222
+ residual_in_fp32=False,
223
+ is_rms_norm=False,
224
+ ):
225
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
226
+ hidden_size = gamma0.numel()
227
+ x0mat = x0.view((-1, hidden_size))
228
+ x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
229
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
230
+ (
231
+ z0mat,
232
+ z1mat,
233
+ xmat,
234
+ dmask0,
235
+ dmask1,
236
+ mu,
237
+ rsigma,
238
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
239
+ x0mat,
240
+ x1mat,
241
+ residualmat,
242
+ gamma0,
243
+ beta0,
244
+ gamma1,
245
+ beta1,
246
+ dropout_p,
247
+ epsilon,
248
+ None,
249
+ residual_in_fp32,
250
+ is_rms_norm,
251
+ )
252
+ # dmask0 and dmask1 are None if dropout_p == 0.0
253
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
254
+ return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
255
+
256
+
257
+ def _dropout_add_layer_norm_parallel_residual_backward(
258
+ dz0,
259
+ dz1,
260
+ dx,
261
+ x,
262
+ dmask0,
263
+ dmask1,
264
+ mu,
265
+ rsigma,
266
+ gamma0,
267
+ gamma1,
268
+ dropout_p,
269
+ has_x1,
270
+ has_residual,
271
+ is_rms_norm=False,
272
+ ):
273
+ """Assume that arguments are contiguous and aligned to 16 bytes
274
+ dx == None means that it was a post-norm architecture
275
+ (x = drop(x0) + residual was not returned in the fwd).
276
+ """
277
+ hidden_size = gamma0.numel()
278
+ xmat = x.view((-1, hidden_size))
279
+ dz0mat = dz0.view(xmat.shape)
280
+ dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
281
+ dxmat = dx.view(xmat.shape) if dx is not None else None
282
+ (
283
+ dx0mat,
284
+ dx1mat,
285
+ dresidualmat,
286
+ dgamma0,
287
+ dbeta0,
288
+ dgamma1,
289
+ dbeta1,
290
+ *rest,
291
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
292
+ dz0mat,
293
+ dz1mat,
294
+ dxmat,
295
+ xmat,
296
+ dmask0,
297
+ dmask1,
298
+ mu,
299
+ rsigma,
300
+ gamma0,
301
+ gamma1,
302
+ dropout_p,
303
+ has_x1,
304
+ has_residual,
305
+ is_rms_norm,
306
+ )
307
+ # dresidualmat is None if not has_residual
308
+ return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
309
+
310
+
311
+ class DropoutAddLayerNormFn(torch.autograd.Function):
312
+ @staticmethod
313
+ def forward(
314
+ ctx,
315
+ x0,
316
+ residual,
317
+ gamma,
318
+ beta,
319
+ rowscale,
320
+ colscale,
321
+ dropout_p,
322
+ epsilon,
323
+ residual_in_fp32=False,
324
+ prenorm=False,
325
+ is_rms_norm=False,
326
+ return_dmask=False,
327
+ ):
328
+ x0 = maybe_align(x0.contiguous(), 16)
329
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
330
+ gamma = maybe_align(gamma.contiguous(), 16)
331
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
332
+ rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
333
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
334
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
335
+ x0,
336
+ residual,
337
+ gamma,
338
+ beta,
339
+ rowscale,
340
+ colscale,
341
+ dropout_p,
342
+ epsilon,
343
+ residual_in_fp32,
344
+ is_rms_norm,
345
+ )
346
+ # Only need to save x0 if we need to compute gradient wrt colscale
347
+ x0_saved = x0 if colscale is not None else None
348
+ ctx.save_for_backward(
349
+ xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
350
+ )
351
+ ctx.prenorm = prenorm
352
+ ctx.dropout_p = dropout_p
353
+ ctx.has_residual = residual is not None
354
+ ctx.is_rms_norm = is_rms_norm
355
+ ctx.has_beta = beta is not None
356
+ if not return_dmask:
357
+ return (
358
+ zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
359
+ )
360
+ else:
361
+ dmask = (
362
+ dmask.view(x0.shape)
363
+ if dropout_p > 0.0
364
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
365
+ )
366
+ ctx.mark_non_differentiable(dmask)
367
+ return (
368
+ (zmat.view(x0.shape), dmask)
369
+ if not prenorm
370
+ else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
371
+ )
372
+
373
+ @staticmethod
374
+ def backward(ctx, dz, *args):
375
+ # assert dz.is_contiguous()
376
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
377
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
378
+ x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
379
+ # x0 is None if colscale is None
380
+ dropout_p = ctx.dropout_p
381
+ has_residual = ctx.has_residual
382
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
383
+ dz,
384
+ dx,
385
+ x,
386
+ x0,
387
+ dmask,
388
+ mu,
389
+ rsigma,
390
+ gamma,
391
+ rowscale,
392
+ colscale,
393
+ dropout_p,
394
+ has_residual,
395
+ ctx.is_rms_norm,
396
+ )
397
+ dx0 = dx0mat.view(x.shape)
398
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
399
+ dcolscale = rest[0] if colscale is not None else None
400
+ return (
401
+ dx0,
402
+ dresidual,
403
+ dgamma,
404
+ dbeta if ctx.has_beta else None,
405
+ None,
406
+ dcolscale,
407
+ None,
408
+ None,
409
+ None,
410
+ None,
411
+ None,
412
+ None,
413
+ )
414
+
415
+
416
+ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
417
+ @staticmethod
418
+ def forward(
419
+ ctx,
420
+ x0,
421
+ residual,
422
+ gamma,
423
+ beta,
424
+ colscale,
425
+ x0_subset,
426
+ out_subset,
427
+ dropout_p,
428
+ epsilon,
429
+ rowscale_const,
430
+ out_numrows,
431
+ residual_in_fp32=False,
432
+ prenorm=False,
433
+ is_rms_norm=False,
434
+ return_dmask=False,
435
+ ):
436
+ x0 = maybe_align(x0.contiguous(), 16)
437
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
438
+ gamma = maybe_align(gamma.contiguous(), 16)
439
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
440
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
441
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
442
+ x0,
443
+ residual,
444
+ gamma,
445
+ beta,
446
+ colscale,
447
+ x0_subset,
448
+ out_subset,
449
+ dropout_p,
450
+ epsilon,
451
+ rowscale_const,
452
+ out_numrows,
453
+ residual_in_fp32,
454
+ is_rms_norm,
455
+ )
456
+ # Only need to save x0 if we need to compute gradient wrt colscale
457
+ x0_saved = x0 if colscale is not None else None
458
+ x_shape = (-1, *x0.shape[1:])
459
+ ctx.save_for_backward(
460
+ xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
461
+ )
462
+ ctx.prenorm = prenorm
463
+ ctx.dropout_p = dropout_p
464
+ ctx.rowscale_const = rowscale_const
465
+ ctx.x0_numrows = x0.shape[:-1].numel()
466
+ ctx.has_residual = residual is not None
467
+ ctx.is_rms_norm = is_rms_norm
468
+ ctx.has_beta = beta is not None
469
+ z_shape = (-1, *x0.shape[1:])
470
+ if not return_dmask:
471
+ return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
472
+ else:
473
+ z = zmat.view(z_shape)
474
+ dmask = (
475
+ dmask.view(x0.shape)
476
+ if dropout_p > 0.0
477
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
478
+ )
479
+ ctx.mark_non_differentiable(dmask)
480
+ return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
481
+
482
+ @staticmethod
483
+ def backward(ctx, dz, *args):
484
+ # assert dz.is_contiguous()
485
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
486
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
487
+ x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
488
+ # x0 is None if colscale is None
489
+ dropout_p = ctx.dropout_p
490
+ has_residual = ctx.has_residual
491
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
492
+ dz,
493
+ dx,
494
+ x,
495
+ x0,
496
+ dmask,
497
+ mu,
498
+ rsigma,
499
+ gamma,
500
+ colscale,
501
+ x0_subset,
502
+ out_subset,
503
+ dropout_p,
504
+ ctx.rowscale_const,
505
+ ctx.x0_numrows,
506
+ has_residual,
507
+ ctx.is_rms_norm,
508
+ )
509
+ dx0 = dx0mat.view(-1, *x.shape[1:])
510
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
511
+ dcolscale = rest[0] if colscale is not None else None
512
+ return (
513
+ dx0,
514
+ dresidual,
515
+ dgamma,
516
+ dbeta if ctx.has_beta else None,
517
+ dcolscale,
518
+ None,
519
+ None,
520
+ None,
521
+ None,
522
+ None,
523
+ None,
524
+ None,
525
+ None,
526
+ None,
527
+ None,
528
+ )
529
+
530
+
531
+ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
532
+ @staticmethod
533
+ def forward(
534
+ ctx,
535
+ x0,
536
+ x1,
537
+ residual,
538
+ gamma0,
539
+ beta0,
540
+ gamma1,
541
+ beta1,
542
+ dropout_p,
543
+ epsilon,
544
+ residual_in_fp32=False,
545
+ prenorm=False,
546
+ is_rms_norm=False,
547
+ return_dmask=False,
548
+ ):
549
+ x0 = maybe_align(x0.contiguous(), 16)
550
+ x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
551
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
552
+ gamma0 = maybe_align(gamma0.contiguous(), 16)
553
+ beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
554
+ gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
555
+ beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
556
+ (
557
+ z0mat,
558
+ z1mat,
559
+ xmat,
560
+ dmask0,
561
+ dmask1,
562
+ mu,
563
+ rsigma,
564
+ ) = _dropout_add_layer_norm_parallel_residual_forward(
565
+ x0,
566
+ x1,
567
+ residual,
568
+ gamma0,
569
+ beta0,
570
+ gamma1,
571
+ beta1,
572
+ dropout_p,
573
+ epsilon,
574
+ residual_in_fp32,
575
+ is_rms_norm,
576
+ )
577
+ ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
578
+ ctx.prenorm = prenorm
579
+ ctx.dropout_p = dropout_p
580
+ ctx.has_x1 = x1 is not None
581
+ ctx.has_residual = residual is not None
582
+ ctx.is_rms_norm = is_rms_norm
583
+ ctx.has_beta = beta0 is not None
584
+ z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
585
+ if not return_dmask:
586
+ return z if not prenorm else (*z, xmat.view(x0.shape))
587
+ else:
588
+ dmask0 = (
589
+ dmask0.view(x0.shape)
590
+ if dropout_p > 0.0
591
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
592
+ )
593
+ dmask1 = (
594
+ dmask1.view(x0.shape)
595
+ if dropout_p > 0.0 and x1 is not None
596
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
597
+ )
598
+ ctx.mark_non_differentiable(dmask0)
599
+ ctx.mark_non_differentiable(dmask1)
600
+ return (
601
+ (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
602
+ )
603
+
604
+ @staticmethod
605
+ def backward(ctx, dz0, dz1, *args):
606
+ dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
607
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
608
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
609
+ x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
610
+ dropout_p = ctx.dropout_p
611
+ has_x1 = ctx.has_x1
612
+ has_residual = ctx.has_residual
613
+ (
614
+ dx0mat,
615
+ dx1mat,
616
+ dresidualmat,
617
+ dgamma0,
618
+ dbeta0,
619
+ dgamma1,
620
+ dbeta1,
621
+ ) = _dropout_add_layer_norm_parallel_residual_backward(
622
+ dz0,
623
+ dz1,
624
+ dx,
625
+ x,
626
+ dmask0,
627
+ dmask1,
628
+ mu,
629
+ rsigma,
630
+ gamma0,
631
+ gamma1,
632
+ dropout_p,
633
+ has_x1,
634
+ has_residual,
635
+ ctx.is_rms_norm,
636
+ )
637
+ dx0 = dx0mat.view(x.shape)
638
+ dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
639
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
640
+ return (
641
+ dx0,
642
+ dx1,
643
+ dresidual,
644
+ dgamma0,
645
+ dbeta0 if ctx.has_beta else None,
646
+ dgamma1,
647
+ dbeta1 if ctx.has_beta else None,
648
+ None,
649
+ None,
650
+ None,
651
+ None,
652
+ None,
653
+ None,
654
+ )
655
+
656
+
657
+ def layer_norm(x, weight, bias, epsilon):
658
+ return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
659
+
660
+
661
+ def dropout_add_layer_norm(
662
+ x0,
663
+ residual,
664
+ weight,
665
+ bias,
666
+ dropout_p,
667
+ epsilon,
668
+ rowscale=None,
669
+ layerscale=None,
670
+ prenorm=False,
671
+ residual_in_fp32=False,
672
+ return_dropout_mask=False,
673
+ ):
674
+ """residual_in_fp32 only has an effect if residual is None.
675
+ Otherwise residual dtype is residual.dtype.
676
+ """
677
+ return DropoutAddLayerNormFn.apply(
678
+ x0,
679
+ residual,
680
+ weight,
681
+ bias,
682
+ rowscale,
683
+ layerscale,
684
+ dropout_p,
685
+ epsilon,
686
+ residual_in_fp32,
687
+ prenorm,
688
+ False,
689
+ return_dropout_mask,
690
+ )
691
+
692
+
693
+ def dropout_add_layer_norm_subset(
694
+ x0,
695
+ residual,
696
+ weight,
697
+ bias,
698
+ dropout_p,
699
+ epsilon,
700
+ layerscale=None,
701
+ x0_subset=None,
702
+ out_subset=None,
703
+ rowscale_const=1.0,
704
+ out_numrows=0,
705
+ prenorm=False,
706
+ residual_in_fp32=False,
707
+ return_dropout_mask=False,
708
+ ):
709
+ """residual_in_fp32 only has an effect if residual is None.
710
+ Otherwise residual dtype is residual.dtype.
711
+ """
712
+ return DropoutAddLayerNormSubsetFn.apply(
713
+ x0,
714
+ residual,
715
+ weight,
716
+ bias,
717
+ layerscale,
718
+ x0_subset,
719
+ out_subset,
720
+ dropout_p,
721
+ epsilon,
722
+ rowscale_const,
723
+ out_numrows,
724
+ residual_in_fp32,
725
+ prenorm,
726
+ False,
727
+ return_dropout_mask,
728
+ )
729
+
730
+
731
+ def dropout_add_layer_norm_parallel_residual(
732
+ x0,
733
+ x1,
734
+ residual,
735
+ weight0,
736
+ bias0,
737
+ weight1,
738
+ bias1,
739
+ dropout_p,
740
+ epsilon,
741
+ prenorm=False,
742
+ residual_in_fp32=False,
743
+ return_dropout_mask=False,
744
+ ):
745
+ """residual_in_fp32 only has an effect if residual is None.
746
+ Otherwise residual dtype is residual.dtype.
747
+ """
748
+ return DropoutAddLayerNormParallelResidualFn.apply(
749
+ x0,
750
+ x1,
751
+ residual,
752
+ weight0,
753
+ bias0,
754
+ weight1,
755
+ bias1,
756
+ dropout_p,
757
+ epsilon,
758
+ residual_in_fp32,
759
+ prenorm,
760
+ False,
761
+ return_dropout_mask,
762
+ )
763
+
764
+
765
+ class DropoutAddLayerNorm(torch.nn.Module):
766
+ def __init__(
767
+ self,
768
+ hidden_size,
769
+ prenorm=False,
770
+ p=0.0,
771
+ eps=1e-5,
772
+ residual_in_fp32=False,
773
+ device=None,
774
+ dtype=None,
775
+ ):
776
+ factory_kwargs = {"device": device, "dtype": dtype}
777
+ super().__init__()
778
+ self.prenorm = prenorm
779
+ self.p = p
780
+ self.eps = eps
781
+ self.residual_in_fp32 = residual_in_fp32
782
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
783
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
784
+ self.reset_parameters()
785
+
786
+ def reset_parameters(self):
787
+ init.ones_(self.weight)
788
+ init.zeros_(self.bias)
789
+
790
+ def forward(self, x0, residual=None):
791
+ return dropout_add_layer_norm(
792
+ x0,
793
+ residual,
794
+ self.weight,
795
+ self.bias,
796
+ self.p if self.training else 0.0,
797
+ self.eps,
798
+ prenorm=self.prenorm,
799
+ residual_in_fp32=self.residual_in_fp32,
800
+ )
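For reference, a plain-PyTorch sketch of what the fused op above computes on the LayerNorm path (a semantic reference under stated assumptions, not the kernel's exact numerics; RMSNorm, rowscale/colscale, and the subset/parallel-residual variants are omitted):

```python
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x0, residual, weight, bias, dropout_p, epsilon,
                               prenorm=False, residual_in_fp32=False, training=True):
    """Unfused reference: z = LayerNorm(dropout(x0) + residual)."""
    dropped = F.dropout(x0, p=dropout_p, training=training)
    if residual is not None:
        x = dropped + residual                     # residual dtype wins via promotion
    else:
        # residual_in_fp32 only matters when no residual tensor exists yet:
        # the running residual stream is then kept in fp32.
        x = dropped.float() if residual_in_fp32 else dropped
    # The kernel normalizes in fp32 and writes the output in the input dtype.
    z = F.layer_norm(x.float(), (weight.numel(),), weight.float(),
                     bias.float() if bias is not None else None, epsilon)
    z = z.to(x0.dtype)
    return (z, x) if prenorm else z
```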
build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aba674c175147bfdff6acb354745749070519df35f66522d71b2743aedc3b5a9
3
+ size 26705096
ln.h ADDED
@@ -0,0 +1,281 @@
1
+ #pragma once
2
+
3
+ #include <unordered_map>
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_bf16.h>
6
+
7
+ #ifdef OLD_GENERATOR_PATH
8
+ #include <ATen/CUDAGeneratorImpl.h>
9
+ #else
10
+ #include <ATen/cuda/CUDAGeneratorImpl.h>
11
+ #endif
12
+
13
+ namespace layer_norm {
14
+
15
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
16
+
17
+ template<typename Params>
18
+ struct LaunchParams{
19
+
20
+ size_t elts_per_thread;
21
+ size_t workspace_bytes;
22
+ size_t barrier_size;
23
+
24
+ cudaDeviceProp * props;
25
+
26
+ cudaStream_t stream;
27
+
28
+ Params params;
29
+
30
+ };
31
+
32
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
33
+
34
+ struct ParamsBase {
35
+ ParamsBase()
36
+ : ctas_per_col(0)
37
+ , rows(0)
38
+ , cols(0)
39
+ , x(nullptr)
40
+ , mu(nullptr)
41
+ , rs(nullptr)
42
+ , gamma(nullptr)
43
+ , gamma1(nullptr)
44
+ , rowscale(nullptr)
45
+ , colscale(nullptr)
46
+ , dropout_keep_p(1.f)
47
+ , dropout_scale(1.f)
48
+ , is_rms_norm(false)
49
+ , workspace(nullptr)
50
+ , barrier(nullptr)
51
+ {
52
+ }
53
+
54
+ // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
55
+ int ctas_per_col;
56
+
57
+ // Input is interpreted as matrix. We normalize across columns.
58
+ int rows;
59
+ int cols;
60
+
61
+ // Common data pointers.
62
+ void *x0;
63
+ void *x1;
64
+ void *residual;
65
+ void *x;
66
+ void *dmask;
67
+ void *dmask1;
68
+ void *mu;
69
+ void *rs;
70
+ void *gamma;
71
+ void *gamma1;
72
+ void *rowscale;
73
+ void *colscale;
74
+ void *x0_subset;
75
+ void *z_subset;
76
+
77
+ float inverse_cols;
78
+
79
+ float dropout_keep_p;
80
+ float dropout_scale;
81
+ float rowscale_const;
82
+
83
+ bool is_rms_norm;
84
+
85
+ // Multi-CTA workspace in gmem.
86
+ void *workspace;
87
+
88
+ // Multi-CTA sync barriers in gmem.
89
+ int *barrier;
90
+
91
+ };
92
+
93
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
94
+
95
+ struct FwdParams : public ParamsBase {
96
+ FwdParams()
97
+ : ParamsBase()
98
+ , z(nullptr)
99
+ , z1(nullptr)
100
+ , beta(nullptr)
101
+ , beta1(nullptr)
102
+ , epsilon(0.f)
103
+ {
104
+ }
105
+
106
+ // Output of LN FWD.
107
+ void *z;
108
+ void *z1;
109
+ void *beta;
110
+ void *beta1;
111
+ float epsilon;
112
+
113
+ // Random state.
114
+ at::PhiloxCudaState philox_args;
115
+ };
116
+
117
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
118
+
119
+ struct BwdParams : public ParamsBase {
120
+ BwdParams()
121
+ : ParamsBase()
122
+ , dz(nullptr)
123
+ , dz1(nullptr)
124
+ , dx(nullptr)
125
+ , dbeta_part(nullptr)
126
+ , dgamma_part(nullptr)
127
+ , dbeta1_part(nullptr)
128
+ , dgamma1_part(nullptr)
129
+ , dcolscale_part(nullptr)
130
+ , dx0(nullptr)
131
+ , dx1(nullptr)
132
+ , dresidual(nullptr)
133
+ , dbeta(nullptr)
134
+ , dgamma(nullptr)
135
+ , dbeta1(nullptr)
136
+ , dgamma1(nullptr)
137
+ , dcolscale(nullptr)
138
+ {
139
+ }
140
+
141
+ // Input: gradient wrt. LN FWD output.
142
+ void *dz;
143
+ void *dz1;
144
+ // Input: gradient wrt residual.
145
+ void *dx;
146
+
147
+ // Workspace for Wgrad pre-reduction.
148
+ void *dbeta_part;
149
+ void *dgamma_part;
150
+ void *dbeta1_part;
151
+ void *dgamma1_part;
152
+ void *dcolscale_part;
153
+
154
+ // Output: Dgrad.
155
+ void *dx0;
156
+ void *dx1;
157
+ void *dresidual;
158
+ // Output: Wgrad.
159
+ void *dbeta;
160
+ void *dgamma;
161
+ void *dbeta1;
162
+ void *dgamma1;
163
+ void *dcolscale;
164
+
165
+ };
166
+
167
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
168
+
169
+ using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
170
+ using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
171
+ using FunctionKey = uint64_t;
172
+ using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
173
+ using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
174
+
175
+ extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
176
+ extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
177
+
178
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
179
+
180
+ using fp32 = float;
181
+ using fp16 = half;
182
+ using bf16 = nv_bfloat16;
183
+
184
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
185
+
186
+ template<typename T>
187
+ struct TypeId{};
188
+
189
+ template<>
190
+ struct TypeId<fp16>{
191
+ constexpr static uint32_t Value = 0;
192
+ };
193
+
194
+ template<>
195
+ struct TypeId<bf16>{
196
+ constexpr static uint32_t Value = 1;
197
+ };
198
+
199
+ template<>
200
+ struct TypeId<fp32>{
201
+ constexpr static uint32_t Value = 2;
202
+ };
203
+
204
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
205
+
206
+ template<typename T, int S>
207
+ struct Type2Key{
208
+ constexpr static uint32_t Value = TypeId<T>::Value << S;
209
+ };
210
+
211
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
212
+
213
+ template<typename T>
214
+ struct WeightType2Key : public Type2Key<T, 0>{};
215
+
216
+ template<typename T>
217
+ struct InputType2Key : public Type2Key<T, 2>{};
218
+
219
+ template<typename T>
220
+ struct ResidualType2Key : public Type2Key<T, 4>{};
221
+
222
+ template<typename T>
223
+ struct OutputType2Key : public Type2Key<T, 6>{};
224
+
225
+ template<typename T>
226
+ struct ComputeType2Key : public Type2Key<T, 8>{};
227
+
228
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
229
+
230
+ template<typename W, typename I, typename R, typename O, typename C>
231
+ struct Types2Key{
232
+ constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
233
+ constexpr static inline uint64_t get(const uint64_t hidden_size){
234
+ constexpr uint64_t type_key = Value;
235
+ return (type_key << 32) | hidden_size;
236
+ }
237
+ };
238
+
239
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
240
+
241
+ template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
242
+ struct FwdRegistrar{
243
+ FwdRegistrar(FwdFunction f){
244
+ uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
245
+ FWD_FUNCS.insert({ key, f });
246
+ }
247
+ };
248
+
249
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
250
+
251
+ template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
252
+ struct BwdRegistrar{
253
+ BwdRegistrar(BwdFunction f){
254
+ uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
255
+ BWD_FUNCS.insert({ key, f });
256
+ }
257
+ };
258
+
259
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
260
+
261
+ template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
262
+ struct FwdParallelRegistrar{
263
+ FwdParallelRegistrar(FwdFunction f){
264
+ uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
265
+ PARALLEL_FWD_FUNCS.insert({ key, f });
266
+ }
267
+ };
268
+
269
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
270
+
271
+ template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
272
+ struct BwdParallelRegistrar{
273
+ BwdParallelRegistrar(BwdFunction f){
274
+ uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
275
+ PARALLEL_BWD_FUNCS.insert({ key, f });
276
+ }
277
+ };
278
+
279
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
280
+
281
+ } // namespace layer_norm
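The registrar templates above key each compiled kernel by the five element types plus the (rounded) hidden size. A small Python sketch of that key packing, re-expressing `Types2Key<W,I,R,O,C>::get` for clarity (the 0/1/2 type ids come from the `TypeId` specializations above; `ln_api.cpp` below computes the same key at runtime in `get_key`):

```python
# Sketch of the 64-bit dispatch key used by FWD_FUNCS / BWD_FUNCS:
# five 2-bit type ids packed into the upper 32 bits, hidden_size in the lower 32.
TYPE_ID = {"fp16": 0, "bf16": 1, "fp32": 2}

def launcher_key(wtype, itype, rtype, otype, ctype, hidden_size):
    type_key = (TYPE_ID[wtype]
                | (TYPE_ID[itype] << 2)
                | (TYPE_ID[rtype] << 4)
                | (TYPE_ID[otype] << 6)
                | (TYPE_ID[ctype] << 8))
    return (type_key << 32) | hidden_size

# e.g. fp16 weights/input/residual/output with fp32 compute at hidden_size 1024
key = launcher_key("fp16", "fp16", "fp16", "fp16", "fp32", 1024)
```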
ln_api.cpp ADDED
@@ -0,0 +1,850 @@
1
+ #include <torch/extension.h>
2
+ #include "ATen/cuda/CUDAContext.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "ln.h"
6
+
7
+ /*
8
+
9
+ Supported Type combinations:
10
+
11
+ input residual compute weights output
12
+ ============================================
13
+ fp32 fp32 fp32 fp32 fp32
14
+ fp16 fp32 fp32 fp32 fp16
15
+ fp16 fp16 fp32 fp32 fp16
16
+ bf16 fp32 fp32 fp32 bf16
17
+ bf16 bf16 fp32 fp32 bf16
18
+ fp16 fp16 fp32 fp16 fp16
19
+ bf16 bf16 fp32 bf16 bf16
20
+
21
+ Remarks:
22
+ Output type = Input type
23
+ Compute always in FP32
24
+
25
+ */
26
+
27
+ namespace layer_norm {
28
+
29
+ // Create registries and provide runtime versions of config hash functions.
30
+
31
+ FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
32
+ BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
33
+
34
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
35
+
36
+ uint32_t get_type_id(torch::Dtype dtype){
37
+ if( dtype == torch::kFloat16 ) {
38
+ return TypeId<fp16>::Value;
39
+ } else if( dtype == torch::kBFloat16 ) {
40
+ return TypeId<bf16>::Value;
41
+ } else if( dtype == torch::kFloat32 ) {
42
+ return TypeId<fp32>::Value;
43
+ } else {
44
+ TORCH_CHECK(false, "Type not supported: ", dtype);
45
+ }
46
+ }
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
51
+ using namespace layer_norm;
52
+ uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
53
+ uint64_t launcher_key = (type_key << 32) | hidden_size;
54
+ return launcher_key;
55
+ }
56
+
57
+ } // namespace layer_norm
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
62
+ auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
63
+ if( iter != layer_norm::FWD_FUNCS.end() ) {
64
+ return iter->second;
65
+ } else {
66
+ TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
67
+ }
68
+ }
69
+
70
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
71
+
72
+ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
73
+ auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
74
+ if( iter != layer_norm::BWD_FUNCS.end() ) {
75
+ return iter->second;
76
+ } else {
77
+ TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
78
+ }
79
+ }
80
+
81
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
+ layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
84
+ auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
85
+ if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
86
+ return iter->second;
87
+ } else {
88
+ TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
89
+ }
90
+ }
91
+
92
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
93
+
94
+ layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
95
+ auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
96
+ if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
97
+ return iter->second;
98
+ } else {
99
+ TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
100
+ }
101
+ }
102
+
103
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
104
+
105
+ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
106
+ c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
107
+ const at::Tensor &gamma, // hidden_size
108
+ c10::optional<const at::Tensor> &beta_, // hidden_size
109
+ c10::optional<const at::Tensor> &rowscale_, // BxS
110
+ c10::optional<const at::Tensor> &colscale_, // hidden_size
111
+ c10::optional<const at::Tensor> &x0_subset_, // BxS
112
+ c10::optional<const at::Tensor> &z_subset_, // BxS
113
+ const float dropout_p,
114
+ const float epsilon,
115
+ const float rowscale_const,
116
+ const int64_t z_numrows,
117
+ c10::optional<at::Generator> gen_,
118
+ bool residual_in_fp32=false,
119
+ bool is_rms_norm=false
120
+ ) {
121
+ auto itype = x0.scalar_type();
122
+ auto rtype = residual_.has_value()
123
+ ? residual_.value().scalar_type()
124
+ : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
125
+ auto wtype = gamma.scalar_type();
126
+ auto otype = itype;
127
+ auto ctype = torch::kFloat32;
128
+ auto mtype = torch::kUInt8;
129
+
130
+ TORCH_CHECK(x0.is_cuda());
131
+ TORCH_CHECK(gamma.is_cuda());
132
+
133
+ TORCH_CHECK(x0.is_contiguous());
134
+ // c10::IntArrayRef does not own the storage, so we need to construct a vector.
135
+ // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
136
+ // blah is then deallocated.
137
+ std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
138
+ auto sizes = c10::IntArrayRef(sizes_vec);
139
+ TORCH_CHECK(x0.dim() == 2);
140
+ TORCH_CHECK(sizes.size() == 2);
141
+
142
+ const int rows = sizes[0];
143
+ const int cols = sizes[1];
144
+ auto hidden_size = gamma.numel();
145
+ TORCH_CHECK(hidden_size == cols);
146
+
147
+ if (beta_.has_value()) {
148
+ auto beta = beta_.value();
149
+ TORCH_CHECK(beta.dtype() == wtype);
150
+ TORCH_CHECK(beta.is_cuda());
151
+ TORCH_CHECK(beta.is_contiguous());
152
+ TORCH_CHECK(beta.sizes() == gamma.sizes());
153
+ }
154
+
155
+ if (residual_.has_value()) {
156
+ auto residual = residual_.value();
157
+ TORCH_CHECK(residual.is_cuda());
158
+ TORCH_CHECK(residual.is_contiguous());
159
+ TORCH_CHECK(residual.sizes() == sizes);
160
+ }
161
+
162
+ if (rowscale_.has_value()) {
163
+ auto rowscale = rowscale_.value();
164
+ TORCH_CHECK(rowscale.is_cuda());
165
+ TORCH_CHECK(rowscale.is_contiguous());
166
+ TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
167
+ TORCH_CHECK(rowscale.dtype() == itype);
168
+ }
169
+
170
+ if (colscale_.has_value()) {
171
+ auto colscale = colscale_.value();
172
+ TORCH_CHECK(colscale.is_cuda());
173
+ TORCH_CHECK(colscale.is_contiguous());
174
+ TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
175
+ TORCH_CHECK(colscale.dtype() == wtype);
176
+ }
177
+
178
+ if (x0_subset_.has_value()) {
179
+ auto x0_subset = x0_subset_.value();
180
+ TORCH_CHECK(x0_subset.is_cuda());
181
+ TORCH_CHECK(x0_subset.is_contiguous());
182
+ TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
183
+ TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
184
+
185
+ TORCH_CHECK(z_subset_.has_value());
186
+ auto z_subset = z_subset_.value();
187
+ TORCH_CHECK(z_subset.is_cuda());
188
+ TORCH_CHECK(z_subset.is_contiguous());
189
+ TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
190
+ TORCH_CHECK(z_subset.dtype() == torch::kInt32);
191
+ }
192
+
193
+ TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
194
+ TORCH_CHECK(epsilon >= 0.f);
195
+
196
+ // Otherwise the kernel will be launched from cuda:0 device
197
+ // Cast to char to avoid compiler warning about narrowing
198
+ at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
199
+
200
+ auto opts = x0.options();
201
+
202
+ bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
203
+ at::Tensor x;
204
+ if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
205
+ at::Tensor dmask;
206
+ if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
207
+ auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
208
+
209
+ auto mu = torch::empty({ rows }, opts.dtype(ctype));
210
+ auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
211
+
212
+ layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
213
+
214
+ launch_params.props = at::cuda::getCurrentDeviceProperties();
215
+ launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
216
+ TORCH_CHECK(dropout_p < 1.f);
217
+ launch_params.params.dropout_keep_p = 1.f - dropout_p;
218
+ launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
219
+ launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
220
+ launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
221
+ launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
222
+ launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
223
+
224
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
225
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
226
+
227
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
228
+ const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
229
+ // Request the kernel launcher.
230
+ auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
231
+
232
+ // Set the kernel runtime parameters.
233
+ layer_norm::FwdParams &params = launch_params.params;
234
+ params.rows = rows;
235
+ params.cols = cols;
236
+ params.x0 = x0.data_ptr();
237
+ params.x = save_x ? x.data_ptr() : nullptr;
238
+ params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
239
+ params.mu = mu.data_ptr();
240
+ params.rs = rsigma.data_ptr();
241
+ params.gamma = gamma.data_ptr();
242
+ params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
243
+ params.z = z.data_ptr();
244
+ params.epsilon = epsilon;
245
+ params.dropout_scale = 1.f / (1.f - dropout_p);
246
+ params.inverse_cols = 1.f / float(params.cols);
247
+ params.rowscale_const = rowscale_const;
248
+ params.is_rms_norm = is_rms_norm;
249
+
250
+ // Query the kernel-specific launch parameters.
251
+ launcher(launch_params, true);
252
+
253
+ at::Tensor workspace, barrier;
254
+
255
+ if (dropout_p > 0.f) {
256
+ // number of times random numbers will be generated per thread, to offset the Philox
258
+ // counter in the THC random state
258
+ int64_t counter_offset = launch_params.elts_per_thread;
259
+
260
+ // See Note [Acquire lock when using random generators]
261
+ {
262
+ std::lock_guard<std::mutex> lock(gen->mutex_);
263
+ params.philox_args = gen->philox_cuda_state(counter_offset);
264
+ }
265
+ }
266
+
267
+ if( launch_params.barrier_size > 0 ) {
268
+ auto options = x0.options();
269
+ barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
270
+ workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
271
+ params.workspace = workspace.data_ptr();
272
+ params.barrier = barrier.data_ptr<int>();
273
+ }
274
+
275
+ // Launch the kernel.
276
+ launcher(launch_params, false);
277
+
278
+ return { z, x, dmask, mu, rsigma };
279
+ }
280
+
281
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
282
+
283
+ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
284
+ c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
285
+ const at::Tensor &x, // BxSxhidden_size
286
+ c10::optional<const at::Tensor> &x0_, // BxSxhidden_size
287
+ c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
288
+ const at::Tensor &mu, // BxS, FP32!
289
+ const at::Tensor &rsigma, // BxS, FP32!
290
+ const at::Tensor &gamma, // hidden_size
291
+ c10::optional<const at::Tensor> &rowscale_, // BxS
292
+ c10::optional<const at::Tensor> &colscale_, // hidden_size
293
+ c10::optional<const at::Tensor> &x0_subset_, // BxS
294
+ c10::optional<const at::Tensor> &z_subset_, // BxS
295
+ const float dropout_p,
296
+ const float rowscale_const,
297
+ const int64_t x0_numrows,
298
+ const bool has_residual,
299
+ bool is_rms_norm=false
300
+ ) {
301
+
302
+ auto itype = dz.scalar_type();
303
+ auto rtype = x.scalar_type();
304
+ auto wtype = gamma.scalar_type();
305
+ auto otype = itype;
306
+ auto ctype = torch::kFloat32;
307
+ auto mtype = torch::kUInt8;
308
+
309
+ if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
310
+
311
+ TORCH_CHECK(dz.dtype() == otype);
312
+ TORCH_CHECK(mu.dtype() == ctype);
313
+ TORCH_CHECK(rsigma.dtype() == ctype);
314
+
315
+ TORCH_CHECK(x.is_cuda());
316
+ TORCH_CHECK(dz.is_cuda());
317
+ TORCH_CHECK(mu.is_cuda());
318
+ TORCH_CHECK(rsigma.is_cuda());
319
+ TORCH_CHECK(gamma.is_cuda());
320
+
321
+ TORCH_CHECK(x.is_contiguous());
322
+ TORCH_CHECK(dz.is_contiguous());
323
+
324
+ auto sizes = x.sizes();
325
+ TORCH_CHECK(sizes.size() == 2);
326
+ auto rows = sizes[0];
327
+ auto cols = sizes[1];
328
+ TORCH_CHECK(dz.dim() == 2);
329
+ TORCH_CHECK(dz.size(1) == cols);
330
+ auto hidden_size = gamma.numel();
331
+ TORCH_CHECK(hidden_size == cols);
332
+
333
+ // c10::IntArrayRef does not own the storage, so we need to construct a vector.
334
+ // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
335
+ // blah is then deallocated.
336
+ std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
337
+ auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
338
+
339
+ if (dx_.has_value()) {
340
+ auto dx = dx_.value();
341
+ TORCH_CHECK(dx.dtype() == rtype);
342
+ TORCH_CHECK(dx.is_cuda());
343
+ TORCH_CHECK(dx.is_contiguous());
344
+ TORCH_CHECK(dx.sizes() == sizes);
345
+ }
346
+
347
+ if (dmask_.has_value()) {
348
+ auto dmask = dmask_.value();
349
+ TORCH_CHECK(dmask.dtype() == mtype);
350
+ TORCH_CHECK(dmask.is_cuda());
351
+ TORCH_CHECK(dmask.is_contiguous());
352
+ TORCH_CHECK(dmask.sizes() == x0_sizes);
353
+ }
354
+
355
+ if (rowscale_.has_value()) {
356
+ auto rowscale = rowscale_.value();
357
+ TORCH_CHECK(rowscale.is_cuda());
358
+ TORCH_CHECK(rowscale.is_contiguous());
359
+ TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
360
+ TORCH_CHECK(rowscale.dtype() == itype);
361
+ }
362
+
363
+ if (colscale_.has_value()) {
364
+ auto colscale = colscale_.value();
365
+ TORCH_CHECK(colscale.is_cuda());
366
+ TORCH_CHECK(colscale.is_contiguous());
367
+ TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
368
+ TORCH_CHECK(colscale.dtype() == wtype);
369
+
370
+ TORCH_CHECK(x0_.has_value());
371
+ auto x0 = x0_.value();
372
+ TORCH_CHECK(x0.is_cuda());
373
+ TORCH_CHECK(x0.is_contiguous());
374
+ TORCH_CHECK(x0.sizes() == x0_sizes);
375
+ TORCH_CHECK(x0.dtype() == itype);
376
+ }
377
+
378
+ if (x0_subset_.has_value()) {
379
+ auto x0_subset = x0_subset_.value();
380
+ TORCH_CHECK(x0_subset.is_cuda());
381
+ TORCH_CHECK(x0_subset.is_contiguous());
382
+ TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
383
+ TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
384
+
385
+ TORCH_CHECK(z_subset_.has_value());
386
+ auto z_subset = z_subset_.value();
387
+ TORCH_CHECK(z_subset.is_cuda());
388
+ TORCH_CHECK(z_subset.is_contiguous());
389
+ TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
390
+ TORCH_CHECK(z_subset.dtype() == torch::kInt32);
391
+ }
392
+
393
+ TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
394
+
395
+ TORCH_CHECK(mu.numel() == rows);
396
+ TORCH_CHECK(mu.sizes() == rsigma.sizes());
397
+
398
+ TORCH_CHECK(gamma.numel() == cols);
399
+
400
+ // Otherwise the kernel will be launched from cuda:0 device
401
+ // Cast to char to avoid compiler warning about narrowing
402
+ at::cuda::CUDAGuard device_guard{(char)dz.get_device()};
403
+
404
+ auto opts = x.options();
405
+
406
+ auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
407
+ at::Tensor dresidual;
408
+ if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
409
+ auto dgamma = torch::empty_like(gamma);
410
+ auto dbeta = torch::empty_like(gamma);
411
+ at::Tensor dcolscale;
412
+ if (colscale_.has_value()) {
413
+ dcolscale = torch::empty_like(colscale_.value());
414
+ }
415
+
416
+ layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
417
+ launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
418
+ launch_params.props = at::cuda::getCurrentDeviceProperties();
419
+ TORCH_CHECK(dropout_p < 1.f);
420
+ launch_params.params.dropout_keep_p = 1.f - dropout_p;
421
+ launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
422
+ launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
423
+ launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
424
+ launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
425
+ launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
426
+
427
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
428
+ const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
429
+ auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
430
+
431
+ launcher(launch_params, true);
432
+
433
+ auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
434
+ auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
435
+ at::Tensor dcolscale_part;
436
+ if (colscale_.has_value()) {
437
+ dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
438
+ }
439
+ at::Tensor workspace, barrier;
440
+
441
+ layer_norm::BwdParams &params = launch_params.params;
442
+ params.rows = rows;
443
+ params.cols = cols;
444
+ params.x = x.data_ptr();
445
+ params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
446
+ params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
447
+ params.mu = mu.data_ptr();
448
+ params.rs = rsigma.data_ptr();
449
+ params.gamma = gamma.data_ptr();
450
+ params.dz = dz.data_ptr();
451
+ params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
452
+ params.dx0 = dx0.data_ptr();
453
+ params.dbeta = dbeta.data_ptr();
454
+ params.dgamma = dgamma.data_ptr();
455
+ params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
456
+ params.dbeta_part = dbeta_part.data_ptr();
457
+ params.dgamma_part = dgamma_part.data_ptr();
458
+ params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
459
+ params.dropout_scale = 1.f / (1.f - dropout_p);
460
+ params.inverse_cols = 1.f / float(params.cols);
461
+ params.rowscale_const = rowscale_const;
462
+ params.is_rms_norm = is_rms_norm;
463
+
464
+ if( launch_params.barrier_size > 0 ) {
465
+ // TODO Any way to avoid this?
466
+ barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
467
+ workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
468
+ params.workspace = workspace.data_ptr();
469
+ params.barrier = barrier.data_ptr<int>();
470
+ }
471
+
472
+ launcher(launch_params, false);
473
+
474
+ std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
475
+ if (colscale_.has_value()) {
476
+ result.push_back(dcolscale);
477
+ result.push_back(dcolscale_part);
478
+ }
479
+ return result;
480
+ }
481
+
482
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
483
+
484
+ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
485
+ const at::Tensor &x0, // Input: BxSxhidden_size
486
+ c10::optional<const at::Tensor> &x1_, // Input: BxSxhidden_size
487
+ c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
488
+ const at::Tensor &gamma0, // hidden_size
489
+ c10::optional<const at::Tensor> &beta0_, // hidden_size
490
+ c10::optional<const at::Tensor> &gamma1_, // hidden_size
491
+ c10::optional<const at::Tensor> &beta1_, // hidden_size
492
+ const float dropout_p,
493
+ const float epsilon,
494
+ c10::optional<at::Generator> gen_,
495
+ bool residual_in_fp32=false,
496
+ bool is_rms_norm=false
497
+ ) {
498
+ auto itype = x0.scalar_type();
499
+ auto rtype = residual_.has_value()
500
+ ? residual_.value().scalar_type()
501
+ : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
502
+ auto wtype = gamma0.scalar_type();
503
+ auto otype = itype;
504
+ auto ctype = torch::kFloat32;
505
+ auto mtype = torch::kUInt8;
506
+
507
+ TORCH_CHECK(x0.is_cuda());
508
+ TORCH_CHECK(gamma0.is_cuda());
509
+
510
+ TORCH_CHECK(x0.is_contiguous());
511
+ const auto sizes = x0.sizes();
512
+ TORCH_CHECK(x0.dim() == 2);
513
+
514
+ const int rows = sizes[0];
515
+ const int cols = sizes[1];
516
+ auto hidden_size = gamma0.numel();
517
+ TORCH_CHECK(hidden_size == cols);
518
+
519
+ if (x1_.has_value()) {
520
+ auto x1 = x1_.value();
521
+ TORCH_CHECK(x1.is_cuda());
522
+ TORCH_CHECK(x1.is_contiguous());
523
+ TORCH_CHECK(x1.sizes() == sizes);
524
+ }
525
+
526
+ if (residual_.has_value()) {
527
+ auto residual = residual_.value();
528
+ TORCH_CHECK(residual.is_cuda());
529
+ TORCH_CHECK(residual.is_contiguous());
530
+ TORCH_CHECK(residual.sizes() == sizes);
531
+ }
532
+
533
+ if (beta0_.has_value()) {
534
+ auto beta0 = beta0_.value();
535
+ TORCH_CHECK(beta0.dtype() == wtype);
536
+ TORCH_CHECK(beta0.is_cuda());
537
+ TORCH_CHECK(beta0.is_contiguous());
538
+ TORCH_CHECK(beta0.sizes() == gamma0.sizes());
539
+ }
540
+
541
+ if (gamma1_.has_value()) {
542
+ auto gamma1 = gamma1_.value();
543
+ TORCH_CHECK(gamma1.dtype() == wtype);
544
+ TORCH_CHECK(gamma1.is_cuda());
545
+ TORCH_CHECK(gamma1.is_contiguous());
546
+ TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
547
+ }
548
+
549
+ if (beta1_.has_value()) {
550
+ auto beta1 = beta1_.value();
551
+ TORCH_CHECK(beta1.dtype() == wtype);
552
+ TORCH_CHECK(beta1.is_cuda());
553
+ TORCH_CHECK(beta1.is_contiguous());
554
+ TORCH_CHECK(beta1.sizes() == gamma0.sizes());
555
+ }
556
+
557
+ TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
558
+ TORCH_CHECK(epsilon >= 0.f);
559
+
560
+ // Otherwise the kernel will be launched from cuda:0 device
561
+ // Cast to char to avoid compiler warning about narrowing
562
+ at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
563
+
564
+ auto opts = x0.options();
565
+
566
+ bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
567
+ at::Tensor x;
568
+ if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
569
+ at::Tensor dmask0, dmask1;
570
+ if (dropout_p > 0.f) {
571
+ dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
572
+ if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
573
+ }
574
+ auto z0 = torch::empty(sizes, opts.dtype(otype));
575
+ at::Tensor z1;
576
+ if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
577
+
578
+ auto mu = torch::empty({ rows }, opts.dtype(ctype));
579
+ auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
580
+
581
+ layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
582
+
583
+ launch_params.props = at::cuda::getCurrentDeviceProperties();
584
+ launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
585
+ TORCH_CHECK(dropout_p < 1.f);
586
+ launch_params.params.dropout_keep_p = 1.f - dropout_p;
587
+ launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
588
+
589
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
590
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
591
+
592
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
593
+ const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
594
+ // Request the kernel launcher.
595
+ auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
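+ // Illustrative note: for, say, hidden_size = 1200 the multiple is 256, so
+ // round_multiple(1200, 256) = 1280 and the launcher instantiated for hidden size 1280
+ // is selected; that kernel then only touches the 1200 real columns (its even-columns
+ // template flag is false when the size is not an exact match).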
596
+
597
+ // Set the kernel runtime parameters.
598
+ layer_norm::FwdParams &params = launch_params.params;
599
+ params.rows = rows;
600
+ params.cols = cols;
601
+ params.x0 = x0.data_ptr();
602
+ params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
603
+ params.x = save_x ? x.data_ptr() : nullptr;
604
+ params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
605
+ params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
606
+ params.mu = mu.data_ptr();
607
+ params.rs = rsigma.data_ptr();
608
+ params.gamma = gamma0.data_ptr();
609
+ params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
610
+ params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
611
+ params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
612
+ params.z = z0.data_ptr();
613
+ params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
614
+ params.epsilon = epsilon;
615
+ params.dropout_scale = 1.f / (1.f - dropout_p);
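+ // (Inverted-dropout convention: kept activations are pre-scaled by 1/(1 - p) here,
+ // so no rescaling is needed at inference time.)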
616
+ params.inverse_cols = 1.f / float(params.cols);
617
+ params.is_rms_norm = is_rms_norm;
618
+
619
+ // Query the kernel-specific launch parameters.
620
+ launcher(launch_params, true);
621
+
622
+ at::Tensor workspace, barrier;
623
+
624
+ if (dropout_p > 0.f) {
625
+ // number of times random will be generated per thread, to offset philox counter in thc random
626
+ // state
627
+ int64_t counter_offset = 2 * launch_params.elts_per_thread;
628
+
629
+ // See Note [Acquire lock when using random generators]
630
+ {
631
+ std::lock_guard<std::mutex> lock(gen->mutex_);
632
+ params.philox_args = gen->philox_cuda_state(counter_offset);
633
+ }
634
+ }
635
+
636
+ if( launch_params.barrier_size > 0 ) {
637
+ auto options = x0.options();
638
+ barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
639
+ workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
640
+ params.workspace = workspace.data_ptr();
641
+ params.barrier = barrier.data_ptr<int>();
642
+ }
643
+
644
+ // Launch the kernel.
645
+ launcher(launch_params, false);
646
+
647
+ return { z0, z1, x, dmask0, dmask1, mu, rsigma };
648
+ }
649
+
650
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
651
+
652
+ std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
653
+ const at::Tensor &dz0, // BxSxhidden_size
654
+ c10::optional<const at::Tensor> &dz1_, // BxSxhidden_size
655
+ c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
656
+ const at::Tensor &x, // BxSxhidden_size
657
+ c10::optional<const at::Tensor> &dmask0_, // BxSxhidden_size
658
+ c10::optional<const at::Tensor> &dmask1_, // BxSxhidden_size
659
+ const at::Tensor &mu, // BxS, FP32!
660
+ const at::Tensor &rsigma, // BxS, FP32!
661
+ const at::Tensor &gamma0, // hidden_size
662
+ c10::optional<const at::Tensor> &gamma1_, // hidden_size
663
+ const float dropout_p,
664
+ const bool has_x1,
665
+ const bool has_residual,
666
+ bool is_rms_norm=false
667
+ ) {
668
+
669
+ auto itype = dz0.scalar_type();
670
+ auto rtype = x.scalar_type();
671
+ auto wtype = gamma0.scalar_type();
672
+ auto otype = itype;
673
+ auto ctype = torch::kFloat32;
674
+ auto mtype = torch::kUInt8;
675
+
676
+ if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
677
+
678
+ TORCH_CHECK(dz0.dtype() == otype);
679
+ TORCH_CHECK(dz0.dtype() == otype);
680
+ TORCH_CHECK(mu.dtype() == ctype);
681
+ TORCH_CHECK(rsigma.dtype() == ctype);
682
+
683
+ TORCH_CHECK(x.is_cuda());
684
+ TORCH_CHECK(dz0.is_cuda());
685
+ TORCH_CHECK(mu.is_cuda());
686
+ TORCH_CHECK(rsigma.is_cuda());
687
+ TORCH_CHECK(gamma0.is_cuda());
688
+
689
+ TORCH_CHECK(x.is_contiguous());
690
+ TORCH_CHECK(dz0.is_contiguous());
691
+
692
+ auto sizes = x.sizes();
693
+ TORCH_CHECK(sizes.size() == 2);
694
+ auto rows = sizes[0];
695
+ auto cols = sizes[1];
696
+ TORCH_CHECK(dz0.dim() == 2);
697
+ TORCH_CHECK(dz0.size(1) == cols);
698
+ auto hidden_size = gamma0.numel();
699
+ TORCH_CHECK(hidden_size == cols);
700
+
701
+ if (dz1_.has_value()) {
702
+ auto dz1 = dz1_.value();
703
+ TORCH_CHECK(dz1.dtype() == otype);
704
+ TORCH_CHECK(dz1.is_cuda());
705
+ TORCH_CHECK(dz1.is_contiguous());
706
+ TORCH_CHECK(dz1.sizes() == sizes);
707
+
708
+ TORCH_CHECK(gamma1_.has_value());
709
+ auto gamma1 = gamma1_.value();
710
+ TORCH_CHECK(gamma1.dtype() == wtype);
711
+ TORCH_CHECK(gamma1.is_cuda());
712
+ TORCH_CHECK(gamma1.is_contiguous());
713
+ TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
714
+ }
715
+
716
+ if (dx_.has_value()) {
717
+ auto dx = dx_.value();
718
+ TORCH_CHECK(dx.dtype() == rtype);
719
+ TORCH_CHECK(dx.is_cuda());
720
+ TORCH_CHECK(dx.is_contiguous());
721
+ TORCH_CHECK(dx.sizes() == sizes);
722
+ }
723
+
724
+ if (dmask0_.has_value()) {
725
+ auto dmask0 = dmask0_.value();
726
+ TORCH_CHECK(dmask0.dtype() == mtype);
727
+ TORCH_CHECK(dmask0.is_cuda());
728
+ TORCH_CHECK(dmask0.is_contiguous());
729
+ TORCH_CHECK(dmask0.sizes() == sizes);
730
+
731
+ if (has_x1) {
732
+ TORCH_CHECK(dmask1_.has_value());
733
+ auto dmask1 = dmask1_.value();
734
+ TORCH_CHECK(dmask1.dtype() == mtype);
735
+ TORCH_CHECK(dmask1.is_cuda());
736
+ TORCH_CHECK(dmask1.is_contiguous());
737
+ TORCH_CHECK(dmask1.sizes() == sizes);
738
+ }
739
+ }
740
+
741
+ TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
742
+
743
+ TORCH_CHECK(mu.numel() == rows);
744
+ TORCH_CHECK(mu.sizes() == rsigma.sizes());
745
+
746
+ // Otherwise the kernel will be launched from cuda:0 device
747
+ // Cast to char to avoid compiler warning about narrowing
748
+ at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
749
+
750
+ auto opts = x.options();
751
+
752
+ auto dx0 = torch::empty(sizes, opts.dtype(itype));
753
+ at::Tensor dx1;
754
+ if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
755
+ at::Tensor dresidual;
756
+ if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
757
+ auto dgamma0 = torch::empty_like(gamma0);
758
+ auto dbeta0 = torch::empty_like(gamma0);
759
+ at::Tensor dgamma1, dbeta1;
760
+ if (gamma1_.has_value()) {
761
+ dgamma1 = torch::empty_like(gamma0);
762
+ dbeta1 = torch::empty_like(gamma0);
763
+ }
764
+
765
+ layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
766
+ launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
767
+ launch_params.props = at::cuda::getCurrentDeviceProperties();
768
+ TORCH_CHECK(dropout_p < 1.f);
769
+ launch_params.params.dropout_keep_p = 1.f - dropout_p;
770
+ launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
771
+
772
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
773
+ const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
774
+ auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
775
+
776
+ launcher(launch_params, true);
777
+
778
+ auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
779
+ auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
780
+ at::Tensor dgamma1_part, dbeta1_part;
781
+ if (gamma1_.has_value()) {
782
+ dgamma1_part = torch::zeros_like(dgamma0_part);
783
+ dbeta1_part = torch::zeros_like(dbeta0_part);
784
+ }
785
+ at::Tensor workspace, barrier;
786
+
787
+ layer_norm::BwdParams &params = launch_params.params;
788
+ params.rows = rows;
789
+ params.cols = cols;
790
+ params.x = x.data_ptr();
791
+ params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
792
+ params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
793
+ params.mu = mu.data_ptr();
794
+ params.rs = rsigma.data_ptr();
795
+ params.gamma = gamma0.data_ptr();
796
+ params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
797
+ params.dz = dz0.data_ptr();
798
+ params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
799
+ params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
800
+ params.dx0 = dx0.data_ptr();
801
+ params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
802
+ params.dbeta = dbeta0.data_ptr();
803
+ params.dgamma = dgamma0.data_ptr();
804
+ params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
805
+ params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
806
+ params.dbeta_part = dbeta0_part.data_ptr();
807
+ params.dgamma_part = dgamma0_part.data_ptr();
808
+ params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
809
+ params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
810
+ params.dropout_scale = 1.f / (1.f - dropout_p);
811
+ params.inverse_cols = 1.f / float(params.cols);
812
+ params.is_rms_norm = is_rms_norm;
813
+
814
+ if( launch_params.barrier_size > 0 ) {
815
+ // TODO Any way to avoid this?
816
+ barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
817
+ workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
818
+ params.workspace = workspace.data_ptr();
819
+ params.barrier = barrier.data_ptr<int>();
820
+ }
821
+
822
+ launcher(launch_params, false);
823
+
824
+ std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
825
+ return result;
826
+ }
827
+
828
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
829
+
830
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
831
+ m.doc() = "CUDA DropoutAddLayerNorm";
832
+ m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
833
+ py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
834
+ py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
835
+ py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
836
+ py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
837
+ m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
838
+ py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
839
+ py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
840
+ py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
841
+ py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
842
+ m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
843
+ py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
844
+ py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
845
+ py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
846
+ m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
847
+ py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
848
+ py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
849
+ py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
850
+ }
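For orientation, below is a minimal sketch of driving the parallel-residual forward binding defined above from Python once the extension is built. It assumes the compiled module imports as `dropout_layer_norm` and that the `c10::optional` arguments accept `None`; the shapes and hyperparameters are illustrative only.

```python
# Sketch only: the module name `dropout_layer_norm` and None-for-optional behaviour are assumptions.
import torch
import dropout_layer_norm

rows, hidden = 4 * 512, 1024   # inputs must be 2D: (batch * seqlen, hidden_size)
x0 = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x0)                       # second (parallel) branch
gamma0 = torch.ones(hidden, device="cuda", dtype=torch.float16)
beta0 = torch.zeros_like(gamma0)
gamma1 = torch.ones_like(gamma0)
beta1 = torch.zeros_like(gamma0)

z0, z1, x, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
    x0, x1, None,        # x0, x1_, residual (no residual stream on the first layer)
    gamma0, beta0,       # gamma0, beta0_
    gamma1, beta1,       # gamma1_, beta1_
    0.1,                 # dropout_p
    1e-5,                # epsilon
    None,                # gen_  -> default CUDA generator
    True,                # residual_in_fp32
    False,               # is_rms_norm
)
# z0/z1 are the two normalized outputs, x is the saved pre-norm residual-stream sum,
# dmask0/dmask1 the dropout masks, and mu/rsigma the per-row statistics needed by the
# backward binding.
```

In a training setup these calls would normally be wrapped in a `torch.autograd.Function` that stashes `x`, `dmask0`, `dmask1`, `mu`, and `rsigma` for `dropout_add_ln_parallel_residual_bwd`.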
ln_bwd_1024.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_1280.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_1536.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
9
+ REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
10
+ REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
11
+ REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
12
+ REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
13
+ REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
14
+ REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
15
+ REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_bwd_2048.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_256.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_2560.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
9
+ REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
10
+ REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
11
+ REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
12
+ REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
13
+ REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
14
+ REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
15
+ REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_bwd_3072.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_4096.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_512.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_5120.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_6144.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
ln_bwd_7168.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
9
+ REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
10
+ REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
11
+ REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
12
+ REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
13
+ REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
14
+ REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
15
+ REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
ln_bwd_768.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_8192.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_bwd_kernels.cuh"
2
+
3
+ // Create backward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
5
+
6
+ REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
7
+ REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
8
+ REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
9
+ REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
10
+ REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
11
+ REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
12
+ REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
13
+ REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
14
+ REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
15
+ REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
ln_bwd_kernels.cuh ADDED
@@ -0,0 +1,534 @@
1
+ #pragma once
2
+
3
+ #include "ln.h"
4
+ #include "ln_utils.cuh"
5
+ #include "ln_kernel_traits.h"
6
+ #include "static_switch.h"
7
+
8
+ namespace layer_norm {
9
+
10
+ template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
11
+ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
12
+ void ln_bwd_kernel(layer_norm::BwdParams params) {
13
+
14
+ enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
15
+ enum { WARPS_M = Ktraits::WARPS_M };
16
+ enum { WARPS_N = Ktraits::WARPS_N };
17
+ enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
18
+ enum { COLS = Ktraits::COLS };
19
+ enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
20
+ enum { LDGS = Ktraits::LDGS };
21
+ enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
22
+ enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
23
+ enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
24
+
25
+ using input_t = typename Ktraits::input_t;
26
+ using compute_t = typename Ktraits::compute_t;
27
+ using index_t = typename Ktraits::index_t;
28
+ using mask_t = typename Ktraits::mask_t;
29
+ using Ivec = typename Ktraits::Ivec;
30
+ using Rvec = typename Ktraits::Rvec;
31
+ using Ovec = typename Ktraits::Ovec;
32
+ using Wvec = typename Ktraits::Wvec;
33
+ using Cvec = typename Ktraits::Cvec;
34
+ using Mvec = typename Ktraits::Mvec;
35
+ using Reducer = typename Ktraits::Reducer;
36
+ using reduce_t = typename Reducer::Type;
37
+
38
+ extern __shared__ char smem_[];
39
+
40
+ const bool has_residual = params.dresidual != nullptr;
41
+ const bool prenorm = params.dx != nullptr;
42
+
43
+ const index_t tidx = threadIdx.x;
44
+ const index_t bidn = blockIdx.x % CTAS_PER_ROW;
45
+ const index_t bidm = blockIdx.x / CTAS_PER_ROW;
46
+ const index_t lane = tidx % THREADS_PER_WARP;
47
+ const index_t warp = tidx / THREADS_PER_WARP;
48
+ const index_t warp_m = warp / Ktraits::WARPS_N;
49
+ const index_t warp_n = warp % Ktraits::WARPS_N;
50
+ const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
51
+
52
+ const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
53
+ const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
54
+
55
+ static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
56
+
57
+ const input_t *rowscale = static_cast<input_t *>(params.rowscale);
58
+ const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
59
+ const index_t *z_subset = static_cast<index_t *>(params.z_subset);
60
+
61
+ Cvec dzy_sum[LDGS];
62
+ Cvec dz_sum[LDGS];
63
+ Cvec dcolscale_sum[LDGS];
64
+
65
+ memset(dzy_sum, 0, sizeof(dzy_sum));
66
+ memset(dz_sum, 0, sizeof(dz_sum));
67
+ if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
68
+
69
+ compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
70
+ char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
71
+
72
+ Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
73
+
74
+ Sum<reduce_t> sum;
75
+
76
+ const index_t num_valid_ldgs =
77
+ ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
78
+
79
+ Wvec gamma[LDGS];
80
+ Wvec colscale[LDGS];
81
+ index_t idx = c;
82
+ #pragma unroll
83
+ for( int it = 0; it < LDGS; it++ ) {
84
+ if (Is_even_cols || (it < num_valid_ldgs)) {
85
+ gamma[it].load_from(params.gamma, idx);
86
+ if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
87
+ idx += Ktraits::VEC_COLS_PER_LDG;
88
+ }
89
+ }
90
+ // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
91
+ // last blocks with syncthreads!
92
+ // grid stride over rows
93
+ #pragma unroll 1
94
+ for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
95
+ const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
96
+ const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
97
+ const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
98
+ const int row_z = !Has_subset ? row + 1 : z_subset[row];
99
+ const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
100
+ const bool load_dz = !Has_subset || row_z > 0;
101
+ const bool save_dx0 = !Has_subset || row_x0 > 0;
102
+ Mvec dmask[LDGS];
103
+ Rvec dx[LDGS];
104
+ compute_t dy[LDGS * NUM_ELTS];
105
+ compute_t y[LDGS * NUM_ELTS];
106
+ compute_t mdy_local = 0.f;
107
+ compute_t mdyy_local = 0.f;
108
+ // If dz is not loaded, then dy should be 0 and we don't care about the value of y.
109
+ if (load_dz) {
110
+ index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
111
+ index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
112
+ index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
113
+ #pragma unroll
114
+ for( int it = 0; it < LDGS; it++ ) {
115
+ if (Is_even_cols || (it < num_valid_ldgs)) {
116
+ Rvec x;
117
+ Ovec dz;
118
+ dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
119
+ if (prenorm) { dx[it].load_from(params.dx, idx_x); }
120
+ x.load_from(params.x, idx_x);
121
+ if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
122
+ idx_x += Ktraits::VEC_COLS_PER_LDG;
123
+ idx_z += Ktraits::VEC_COLS_PER_LDG;
124
+ idx_x0 += Ktraits::VEC_COLS_PER_LDG;
125
+ #pragma unroll
126
+ for( int jt = 0; jt < NUM_ELTS; jt++ ) {
127
+ compute_t x_tmp = x.data.elt[jt];
128
+ compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
129
+ compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
130
+ compute_t dz_tmp = dz.data.elt[jt];
131
+
132
+ mdy_local += dy_tmp;
133
+ mdyy_local += dy_tmp * y_tmp;
134
+
135
+ dy[it * NUM_ELTS + jt] = dy_tmp;
136
+ y[it * NUM_ELTS + jt] = y_tmp;
137
+
138
+ dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
139
+ dz_sum[it].data.elt[jt] += dz_tmp;
140
+ }
141
+ }
142
+ }
143
+ } else {
144
+ index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
145
+ index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
146
+ #pragma unroll
147
+ for( int it = 0; it < LDGS; it++ ) {
148
+ if (Is_even_cols || (it < num_valid_ldgs)) {
149
+ if (prenorm) { dx[it].load_from(params.dx, idx_x); }
150
+ if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
151
+ idx_x += Ktraits::VEC_COLS_PER_LDG;
152
+ idx_x0 += Ktraits::VEC_COLS_PER_LDG;
153
+ }
154
+ }
155
+ }
156
+
157
+ reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
158
+ mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
159
+ mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
160
+
161
+ index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
162
+ index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
163
+ #pragma unroll
164
+ for( int it = 0; it < LDGS; it++ ) {
165
+ if (Is_even_cols || (it < num_valid_ldgs)) {
166
+ Ivec dx0;
167
+ Rvec dresidual;
168
+ Ivec x0;
169
+ if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
170
+ #pragma unroll
171
+ for( int jt = 0; jt < NUM_ELTS; jt++ ) {
172
+ compute_t dx_tmp_res;
173
+ if (load_dz) {
174
+ compute_t dy_tmp = dy[it * NUM_ELTS + jt];
175
+ compute_t y_tmp = y[it * NUM_ELTS + jt];
176
+ compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
177
+ dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
178
+ } else {
179
+ dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
180
+ }
181
+ if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
182
+ if (save_dx0) {
183
+ compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
184
+ if (Is_dropout) {
185
+ dx0_tmp_res *= params.dropout_scale;
186
+ if (Has_colscale) {
187
+ dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
188
+ dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
189
+ } else {
190
+ dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
191
+ }
192
+ } else {
193
+ if (Has_colscale) {
194
+ dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
195
+ dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
196
+ } else {
197
+ dx0.data.elt[jt] = dx0_tmp_res;
198
+ }
199
+ }
200
+ }
201
+ }
202
+ if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
203
+ if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
204
+ idx_x += Ktraits::VEC_COLS_PER_LDG;
205
+ idx_x0 += Ktraits::VEC_COLS_PER_LDG;
206
+ }
207
+ }
208
+
209
+ } // end: grid stride loop
210
+
211
+ if( WARPS_M == 1 ) {
212
+ idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
213
+ #pragma unroll
214
+ for( int it = 0; it < LDGS; it++ ) {
215
+ if (Is_even_cols || (it < num_valid_ldgs)) {
216
+ dz_sum[it].store_to(params.dbeta_part, idx);
217
+ dzy_sum[it].store_to(params.dgamma_part, idx);
218
+ if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
219
+ idx += Ktraits::VEC_COLS_PER_LDG;
220
+ }
221
+ }
222
+ } else {
223
+ static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
224
+ // Finalize reduction of part dgamma and dbeta for this CTA
225
+ // by reducing over the rows held across the WARPS_M warps
226
+
227
+ // Assumption: blockSize divides hidden size.
228
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
229
+ static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
230
+
231
+ idx = warp_m * Ktraits::VEC_COLS + tid_r;
232
+ #pragma unroll
233
+ for( int it = 0; it < LDGS; it++ ) {
234
+ dz_sum[it].store_to(smem_wgrad, idx);
235
+ idx += THREADS_PER_ROW;
236
+ }
237
+ __syncthreads();
238
+ compute_t cta_dz_sum[NUM_RES];
239
+ memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
240
+ for( int it = 0; it < ROWS_PER_CTA; it++ ) {
241
+ for( int jt = 0; jt < NUM_RES; jt++ ) {
242
+ cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
243
+ }
244
+ }
245
+ __syncthreads();
246
+
247
+ idx = warp_m * Ktraits::VEC_COLS + tid_r;
248
+ #pragma unroll
249
+ for( int it = 0; it < LDGS; it++ ) {
250
+ dzy_sum[it].store_to(smem_wgrad, idx);
251
+ idx += THREADS_PER_ROW;
252
+ }
253
+ __syncthreads();
254
+ compute_t cta_dzy_sum[NUM_RES];
255
+ memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
256
+ for( int it = 0; it < ROWS_PER_CTA; it++ ) {
257
+ for( int jt = 0; jt < NUM_RES; jt++ ) {
258
+ cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
259
+ }
260
+ }
261
+
262
+ compute_t cta_dcolscale_sum[NUM_RES];
263
+ if (Has_colscale) {
264
+ __syncthreads();
265
+ idx = warp_m * Ktraits::VEC_COLS + tid_r;
266
+ #pragma unroll
267
+ for( int it = 0; it < LDGS; it++ ) {
268
+ dcolscale_sum[it].store_to(smem_wgrad, idx);
269
+ idx += THREADS_PER_ROW;
270
+ }
271
+ __syncthreads();
272
+ memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
273
+ for( int it = 0; it < ROWS_PER_CTA; it++ ) {
274
+ for( int jt = 0; jt < NUM_RES; jt++ ) {
275
+ cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
276
+ }
277
+ }
278
+ }
279
+
280
+ const index_t num_valid_writes
281
+ = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
282
+ compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
283
+ compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
284
+ compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
285
+ for( int jt = 0; jt < NUM_RES; jt++ ) {
286
+ if (Is_even_cols || (jt < num_valid_writes)) {
287
+ *dgamma_part = cta_dzy_sum[jt];
288
+ dgamma_part += Ktraits::THREADS_PER_CTA;
289
+ *dbeta_part = cta_dz_sum[jt];
290
+ dbeta_part += Ktraits::THREADS_PER_CTA;
291
+ if (Has_colscale) {
292
+ *dcolscale_part = cta_dcolscale_sum[jt];
293
+ dcolscale_part += Ktraits::THREADS_PER_CTA;
294
+ }
295
+ }
296
+ }
297
+
298
+ }
299
+ }
300
+
301
+ template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
302
+ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
303
+ void ln_bwd_finalize_kernel(BwdParams params)
304
+ {
305
+
306
+ using compute_t = typename Kernel_traits::compute_t;
307
+ using weight_t = typename Kernel_traits::weight_t;
308
+ using index_t = typename Kernel_traits::index_t;
309
+ using Reducer = typename Kernel_traits::Reducer;
310
+ using reduce_t = typename Reducer::Type;
311
+
312
+ Sum<reduce_t> sum;
313
+ enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
314
+ enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
315
+
316
+ __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
317
+
318
+ constexpr uint32_t bidm = 0;
319
+
320
+ const uint32_t bidn = blockIdx.x;
321
+ const uint32_t tidx = threadIdx.x;
322
+ const uint32_t warp = tidx / THREADS_PER_WARP;
323
+ const uint32_t lane = tidx % THREADS_PER_WARP;
324
+
325
+ Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
326
+
327
+ const uint32_t c = bidn * THREADS_PER_WARP + lane;
328
+ const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
329
+ constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
330
+ for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
331
+ // Each thread sums over NUM_ELT columns.
332
+ Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
333
+ memset(&dgamma_local, 0, sizeof(dgamma_local));
334
+ memset(&dbeta_local, 0, sizeof(dbeta_local));
335
+ if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
336
+ if (Is_even_cols || col < params.cols) {
337
+ for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
338
+ index_t idx = row * params.cols + col;
339
+
340
+ Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
341
+ dbeta_part.load_from(params.dbeta_part, idx);
342
+ dgamma_part.load_from(params.dgamma_part, idx);
343
+ if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
344
+ #pragma unroll
345
+ for( int it = 0; it < NUM_ELT; it++ ) {
346
+ dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
347
+ dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
348
+ if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
349
+ }
350
+ }
351
+ }
352
+ void * smem_gamma = smem_;
353
+ void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
354
+ void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
355
+
356
+ const int write_row = warp;
357
+ const int write_col = lane ^ write_row;
358
+ const int write_idx = write_row * THREADS_PER_WARP + write_col;
359
+
360
+ dgamma_local.store_to(smem_gamma, write_idx);
361
+ dbeta_local.store_to(smem_beta, write_idx);
362
+ if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
363
+
364
+ __syncthreads();
365
+
366
+ // It would be probably safe to reuse the first row of smem_beta and smem_gamma
367
+ void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
368
+ void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
369
+ void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
370
+
371
+
372
+ // More than one iter iff ROWS_PER_CTA < 32.
373
+ for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
374
+ const int read_row = lane;
375
+ const int read_col = w ^ read_row;
376
+ const int read_idx = read_row * THREADS_PER_WARP + read_col;
377
+
378
+ memset(&dbeta_local, 0, sizeof(dbeta_local));
379
+ memset(&dgamma_local, 0, sizeof(dgamma_local));
380
+ if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
381
+
382
+ // Load beta and gamma transposed
383
+ if(read_row < Kernel_traits::ROWS_PER_CTA){
384
+ dbeta_local.load_from(smem_beta, read_idx);
385
+ dgamma_local.load_from(smem_gamma, read_idx);
386
+ if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
387
+ }
388
+
389
+ // Call reducer on the loaded value(s) and convert.
390
+ #pragma unroll
391
+ for( int it = 0; it < NUM_ELT; it++ ) {
392
+ compute_t b_i = dbeta_local.data.elt[it];
393
+ compute_t g_i = dgamma_local.data.elt[it];
394
+ b_i = reducer.allreduce(b_i, sum);
395
+ g_i = reducer.allreduce(g_i, sum);
396
+
397
+ dgamma_local.data.elt[it] = g_i;
398
+ dbeta_local.data.elt[it] = b_i;
399
+ if (Has_colscale) {
400
+ compute_t cs_i = dcolscale_local.data.elt[it];
401
+ cs_i = reducer.allreduce(cs_i, sum);
402
+ dcolscale_local.data.elt[it] = cs_i;
403
+ }
404
+ }
405
+
406
+ // Leader stores the result at the current column.
407
+ if(lane == 0){
408
+ dgamma_local.store_to(smem_gamma_out, w);
409
+ dbeta_local.store_to(smem_beta_out, w);
410
+ if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
411
+ }
412
+
413
+ }
414
+
415
+ // All writes done.
416
+ __syncthreads();
417
+
418
+ // Pack and store: 2-wide stores with half the threads.
419
+ if (Is_even_cols || col_out * 2 < params.cols) {
420
+ if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
421
+
422
+ using src_t = typename TypeToVec2<compute_t>::Type;
423
+ using dst_t = typename TypeToVec2<weight_t>::Type;
424
+ Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
425
+ Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
426
+
427
+ dgamma_vec2.load_from(smem_gamma_out, lane);
428
+ dbeta_vec2.load_from(smem_beta_out, lane);
429
+ if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
430
+ #pragma unroll
431
+ for( int it = 0; it < NUM_ELT; it++ ) {
432
+ dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
433
+ dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
434
+ if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
435
+ }
436
+ dgamma_out2.store_to(params.dgamma, col_out);
437
+ dbeta_out2.store_to(params.dbeta, col_out);
438
+ if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
439
+ }
440
+ }
441
+ }
442
+ }
443
+ } // namespace layer_norm
444
+
445
+ using namespace layer_norm;
446
+
447
+ template<
448
+ typename weight_t,
449
+ typename input_t,
450
+ typename residual_t,
451
+ typename output_t,
452
+ typename compute_t,
453
+ typename index_t,
454
+ int HIDDEN_SIZE,
455
+ int CTAS_PER_ROW,
456
+ int WARPS_M,
457
+ int WARPS_N,
458
+ int BYTES_PER_LDG_MAIN,
459
+ int BYTES_PER_LDG_FINAL
460
+ >
461
+ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
462
+
463
+ using Kernel_traits = Kernel_traits<weight_t,
464
+ input_t,
465
+ residual_t,
466
+ output_t,
467
+ compute_t,
468
+ index_t,
469
+ HIDDEN_SIZE,
470
+ CTAS_PER_ROW,
471
+ WARPS_M,
472
+ WARPS_N,
473
+ BYTES_PER_LDG_MAIN
474
+ >;
475
+ bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
476
+ bool has_colscale = launch_params.params.colscale != nullptr;
477
+ bool has_subset = launch_params.params.x0_subset != nullptr;
478
+ bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
479
+ BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
480
+ BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
481
+ BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
482
+ BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
483
+ auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
484
+ if( configure_params ) {
485
+ int ctas_per_sm;
486
+ CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
487
+ &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
488
+ launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
489
+ launch_params.barrier_size = 0;
490
+ launch_params.workspace_bytes = 0;
491
+ if(Kernel_traits::CTAS_PER_ROW > 1) {
492
+ launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
493
+ launch_params.workspace_bytes = launch_params.params.ctas_per_col
494
+ * Kernel_traits::WARPS_M
495
+ * Kernel_traits::CTAS_PER_ROW
496
+ * sizeof(typename Kernel_traits::reduce_t)
497
+ * 2;
498
+ }
499
+ return;
500
+ }
501
+
502
+ if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
503
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
504
+ }
505
+ auto stream = launch_params.stream;
506
+ auto ctas_per_col = launch_params.params.ctas_per_col;
507
+
508
+ if( Kernel_traits::CTAS_PER_ROW == 1 ) {
509
+ kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
510
+ } else {
511
+ dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
512
+ dim3 block(Kernel_traits::THREADS_PER_CTA);
513
+ void *params_ = (void *)&launch_params.params;
514
+ cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
515
+ }
516
+
517
+ using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
518
+ weight_t,
519
+ input_t,
520
+ residual_t,
521
+ output_t,
522
+ compute_t,
523
+ index_t,
524
+ HasColscaleConst,
525
+ 32 * 32, // THREADS_PER_CTA
526
+ BYTES_PER_LDG_FINAL>;
527
+
528
+ auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
529
+ kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
530
+ });
531
+ });
532
+ });
533
+ });
534
+ }
ln_fwd_1024.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_fwd_kernels.cuh"
2
+
3
+ // Create forward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
5
+
6
+ REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
7
+ REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
8
+ REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
9
+ REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
10
+ REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
11
+ REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
12
+ REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
13
+ REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
14
+ REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
15
+ REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_1280.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_fwd_kernels.cuh"
2
+
3
+ // Create forward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
5
+
6
+ REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
7
+ REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
8
+ REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
9
+ REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
10
+ REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
11
+ REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
12
+ REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
13
+ REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
14
+ REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
15
+ REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_1536.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_fwd_kernels.cuh"
2
+
3
+ // Create forward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
5
+
6
+ REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
7
+ REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
8
+ REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
9
+ REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
10
+ REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
11
+ REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
12
+ REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
13
+ REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
14
+ REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
15
+ REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_2048.cu ADDED
@@ -0,0 +1,15 @@
1
+ #include "ln_fwd_kernels.cuh"
2
+
3
+ // Create forward launch function and register. Macro signature:
4
+ // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
5
+
6
+ REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
7
+ REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
8
+ REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
9
+ REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
10
+ REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
11
+ REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
12
+ REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
13
+ REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
14
+ REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
15
+ REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_256.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_2560.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_3072.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_4096.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_512.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_5120.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_6144.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
ln_fwd_7168.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+ REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_768.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+ REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_8192.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_fwd_kernels.cuh"
+
+ // Create forward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+ REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+ REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
ln_fwd_kernels.cuh ADDED
@@ -0,0 +1,272 @@
+ #pragma once
+
+ #ifdef OLD_GENERATOR_PATH
+ #include <ATen/CUDAGeneratorImpl.h>
+ #else
+ #include <ATen/cuda/CUDAGeneratorImpl.h>
+ #endif
+
+ #include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
+ #include <curand_kernel.h>
+
+ #include "ln.h"
+ #include "ln_utils.cuh"
+ #include "ln_kernel_traits.h"
+ #include "static_switch.h"
+
+ namespace layer_norm {
+
+ template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
+ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
+ void ln_fwd_kernel(FwdParams params) {
+
+ enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+ enum { WARPS_N = Ktraits::WARPS_N };
+ enum { WARPS_M = Ktraits::WARPS_M };
+ enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+ enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
+ enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+ enum { LDGS = Ktraits::LDGS };
+ enum { NUM_ELTS = Ktraits::NUM_ELTS };
+ enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+ using input_t = typename Ktraits::input_t;
+ using residual_t = typename Ktraits::residual_t;
+ using output_t = typename Ktraits::output_t;
+ using index_t = typename Ktraits::index_t;
+ using compute_t = typename Ktraits::compute_t;
+ using mask_t = typename Ktraits::mask_t;
+ using Ivec = typename Ktraits::Ivec;
+ using Rvec = typename Ktraits::Rvec;
+ using Ovec = typename Ktraits::Ovec;
+ using Wvec = typename Ktraits::Wvec;
+ using Cvec = typename Ktraits::Cvec;
+ using Mvec = typename Ktraits::Mvec;
+
+ using Stats = typename Ktraits::Stats;
+ using stats_t = typename Stats::stats_t;
+
+ const bool has_residual = params.residual != nullptr;
+ const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
+
+ extern __shared__ char smem_[];
+
+ const index_t tidx = threadIdx.x;
+ const index_t bidn = blockIdx.x % CTAS_PER_ROW;
+ const index_t bidm = blockIdx.x / CTAS_PER_ROW;
+ const index_t lane = tidx % THREADS_PER_WARP;
+ const index_t warp = tidx / THREADS_PER_WARP;
+ const index_t warp_m = warp / WARPS_N;
+ const index_t warp_n = warp % WARPS_N;
+
+ const index_t r = bidm * ROWS_PER_CTA + warp_m;
+ const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
+
+ Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
+
+ compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
+ compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
+
+ const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+ const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
+ const index_t *z_subset = static_cast<index_t *>(params.z_subset);
+
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
+ curandStatePhilox4_32_10_t state;
+ if (Is_dropout) {
+ auto seeds = at::cuda::philox::unpack(params.philox_args);
+ const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
+ curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
+ }
+
+ const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
+
+ Wvec gamma[LDGS];
+ Wvec beta[LDGS];
+ Wvec colscale[LDGS];
+ index_t idx = c;
+ #pragma unroll
+ for( int it = 0; it < LDGS; it++ ) {
+ if (Is_even_cols || (it < num_valid_ldgs)) {
+ gamma[it].load_from(params.gamma, idx);
+ if (params.beta != nullptr) {
+ beta[it].load_from(params.beta, idx);
+ } else {
+ beta[it].zero_();
+ }
+ if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
+ idx += VEC_COLS_PER_LDG;
+ }
+ }
+
+ for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
+ const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
+ const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
+ const int row_z = !Has_subset ? row + 1 : z_subset[row];
+ const bool load_x0 = !Has_subset || row_x0 > 0;
+ index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+ index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+ compute_t xf[LDGS * NUM_ELTS];
+ #pragma unroll
+ for( int it = 0; it < LDGS; it++ ) {
+ if (Is_even_cols || (it < num_valid_ldgs)) {
+ Ivec x0;
+ Rvec residual;
+ Rvec x;
+ Mvec dmask;
+ if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
+ if (has_residual) { residual.load_from(params.residual, idx_x); }
+ #pragma unroll
+ for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+ // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+ // the more efficient curand_uniform4.
+ compute_t x_ij;
+ if (load_x0) {
+ mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+ if (Is_dropout) { dmask.data.elt[jt] = keep; }
+ compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
+ x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+ if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
+ x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
+ } else {
+ x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
+ }
+ if (save_x) { x.data.elt[jt] = x_ij; }
+ xf[it * NUM_ELTS + jt] = x_ij;
+ }
+ if (save_x) { x.store_to(params.x, idx_x); }
+ if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
+ idx_x += VEC_COLS_PER_LDG;
+ idx_x0 += VEC_COLS_PER_LDG;
+ }
+ }
+
+ static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
+ const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
+ const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
+ const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
+ auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+ // Need to convert to int, otherwise the subtraction will wrap around.
+ const index_t valid_partial_vecs_in_warp =
+ std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
+ int(THREADS_PER_WARP));
+ return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
+ };
+ stats_t s = stats.template compute<Is_even_cols>(
+ xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
+ );
+
+ compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
+ compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
+
+ if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+ mu_ptr[row] = mu;
+ }
+
+ compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
+
+ if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+ rs_ptr[row] = rs;
+ }
+
+ const bool save_z = !Has_subset || row_z > 0;
+ if (save_z) {
+ index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
+ #pragma unroll
+ for( int it = 0; it < LDGS; it++ ) {
+ if (Is_even_cols || (it < num_valid_ldgs)) {
+ Ovec z;
+ #pragma unroll
+ for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+ compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
+ compute_t g_ij = gamma[it].data.elt[jt];
+ compute_t b_ij = beta[it].data.elt[jt];
+ z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
+ }
+ z.store_to(params.z, idx_z);
+ idx_z += VEC_COLS_PER_LDG;
+ }
+ }
+ }
+
+ }
+ }
+
+ } // namespace layer_norm
+
+ using namespace layer_norm;
+
+ template<
+ typename weight_t,
+ typename input_t,
+ typename residual_t,
+ typename output_t,
+ typename compute_t,
+ typename index_t,
+ int HIDDEN_SIZE,
+ int CTAS_PER_ROW,
+ int WARPS_M,
+ int WARPS_N,
+ int BYTES_PER_LDG
+ >
+ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
+
+ using Kernel_traits = Kernel_traits<weight_t,
+ input_t,
+ residual_t,
+ output_t,
+ compute_t,
+ index_t,
+ HIDDEN_SIZE,
+ CTAS_PER_ROW,
+ WARPS_M,
+ WARPS_N,
+ BYTES_PER_LDG
+ >;
+ bool has_colscale = launch_params.params.colscale != nullptr;
+ bool has_subset = launch_params.params.x0_subset != nullptr;
+ bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+ BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+ BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+ BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
+ BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+ auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
+ if( configure_params ) {
+ int ctas_per_sm;
+ CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
+ launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+ const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+ launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+ launch_params.barrier_size = 0;
+ launch_params.workspace_bytes = 0;
+ if(Kernel_traits::CTAS_PER_ROW > 1) {
+ launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+ launch_params.workspace_bytes = launch_params.params.ctas_per_col
+ * Kernel_traits::WARPS_M
+ * Kernel_traits::CTAS_PER_ROW
+ * sizeof(typename Kernel_traits::Stats::stats_t)
+ * 2;
+ }
+ return;
+ }
+
+ if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+ }
+ auto stream = launch_params.stream;
+ auto ctas_per_col = launch_params.params.ctas_per_col;
+
+ if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+ kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+ } else {
+ dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+ dim3 block(Kernel_traits::THREADS_PER_CTA);
+ void *params_ = (void *)&launch_params.params;
+ cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
+ }
+ });
+ });
+ });
+ });
+ }
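Restated in scalar form, one row of the fused forward pass above computes dropout on `x0`, adds the residual, and then applies LayerNorm or RMSNorm to the sum. The sketch below is a plain C++ reference of that per-row math, not the vectorized kernel: it ignores the subset/rowscale/colscale paths and takes the dropout mask as an input instead of drawing it from cuRAND, and all names are local to the sketch.

```cpp
// Scalar reference for one row of the fused dropout + residual + (Layer|RMS)Norm forward.
#include <cmath>
#include <vector>

struct RowOut { std::vector<float> x, z; float mu, rs; };

RowOut ln_fwd_row_reference(const std::vector<float>& x0,
                            const std::vector<float>& residual,   // may be empty
                            const std::vector<float>& gamma,
                            const std::vector<float>& beta,
                            const std::vector<bool>& keep_mask,   // dropout mask (dmask)
                            float dropout_keep_p, float epsilon, bool is_rms_norm) {
    const size_t n = x0.size();
    RowOut out{std::vector<float>(n), std::vector<float>(n), 0.f, 0.f};
    const float dropout_scale = 1.f / dropout_keep_p;
    // x = dropout(x0) + residual  (this is what the kernel writes to params.x)
    for (size_t j = 0; j < n; ++j) {
        float v = keep_mask[j] ? x0[j] * dropout_scale : 0.f;
        if (!residual.empty()) v += residual[j];
        out.x[j] = v;
    }
    // Row statistics (mean and biased second moment), as produced by Stats::compute.
    float mu = 0.f;
    for (float v : out.x) mu += v;
    mu /= n;
    float m2 = 0.f;
    for (float v : out.x) m2 += (v - mu) * (v - mu);
    // Matches rsqrtf(m2 * inverse_cols + eps + (is_rms_norm ? mu*mu : 0)) in the kernel:
    // 1/sqrt(var + eps) for LayerNorm, 1/sqrt(mean(x^2) + eps) for RMSNorm.
    float rs = 1.f / std::sqrt(m2 / n + epsilon + (is_rms_norm ? mu * mu : 0.f));
    out.mu = mu; out.rs = rs;
    // z = gamma * (x - mu) * rs + beta   (mu is dropped for RMSNorm)
    for (size_t j = 0; j < n; ++j) {
        float y = rs * (out.x[j] - (is_rms_norm ? 0.f : mu));
        out.z[j] = gamma[j] * y + beta[j];
    }
    return out;
}
```

Note also that `launch_` is called in two phases: once with `configure_params` set, where it only queries occupancy and sizes the grid and workspace, and once more to actually launch the kernel on the provided stream.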
ln_kernel_traits.h ADDED
@@ -0,0 +1,172 @@
+ #pragma once
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace layer_norm {
+ template<
+ uint32_t HIDDEN_SIZE_,
+ typename weight_t_,
+ typename input_t_,
+ typename residual_t_,
+ typename output_t_,
+ typename compute_t_,
+ typename index_t_,
+ uint32_t THREADS_PER_CTA_
+ >
+ struct Kernel_traits_base {
+
+ using weight_t = weight_t_;
+ using input_t = input_t_;
+ using residual_t = residual_t_;
+ using output_t = output_t_;
+ using compute_t = compute_t_;
+ using index_t = index_t_;
+
+ enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
+ enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
+ enum { THREADS_PER_WARP = 32 };
+
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<
+ uint32_t HIDDEN_SIZE_,
+ typename weight_t_,
+ typename input_t_,
+ typename residual_t_,
+ typename output_t_,
+ typename compute_t_,
+ typename index_t_,
+ bool Has_colscale,
+ uint32_t THREADS_PER_CTA_,
+ uint32_t BYTES_PER_LDG_,
+ typename Base = Kernel_traits_base<HIDDEN_SIZE_,
+ weight_t_,
+ input_t_,
+ residual_t_,
+ output_t_,
+ compute_t_,
+ index_t_,
+ THREADS_PER_CTA_>
+ >
+ struct Kernel_traits_finalize : public Base {
+ enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
+ static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
+ // Bytes per global load from the input.
+ enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
+ // Number of elements fetched by a global load.
+ enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
+ // Bytes per global store of the weights.
+ enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
+ static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
+ static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
+ // The total number of BYTES_PER_LDG-wide words in a hidden vector.
+ enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
+ static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
+
+ // Shared memory size to transpose the CTA result.
+ enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
+ // Shared memory size to coalesce the CTA result.
+ enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
+ // Shared memory requirement per CTA.
+ static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
+ enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
+
+ // The type of the reducer.
+ using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
+
+ // Condition for the whole CTA to participate in syncthreads.
+ static_assert(COLS % Base::THREADS_PER_WARP == 0);
+ enum { CTAS = COLS / Base::THREADS_PER_WARP };
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+ template<
+ typename weight_t_,
+ typename input_t_,
+ typename residual_t_,
+ typename output_t_,
+ typename compute_t_,
+ typename index_t_,
+ uint32_t HIDDEN_SIZE_,
+ uint32_t CTAS_PER_ROW_,
+ uint32_t WARPS_M_,
+ uint32_t WARPS_N_,
+ uint32_t BYTES_PER_LDG_ = 16,
+ typename Base = Kernel_traits_base<
+ HIDDEN_SIZE_,
+ weight_t_,
+ input_t_,
+ residual_t_,
+ output_t_,
+ compute_t_,
+ index_t_,
+ WARPS_M_*WARPS_N_*THREADS_PER_WARP
+ >
+ >
+ struct Kernel_traits : public Base {
+
+ using input_t = typename Base::input_t;
+ using residual_t = typename Base::residual_t;
+ using weight_t = typename Base::weight_t;
+ using compute_t = typename Base::compute_t;
+ using output_t = typename Base::output_t;
+ using index_t = typename Base::index_t;
+ // using mask_t = unsigned char;
+ using mask_t = bool;
+
+ enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
+ enum { WARPS_M = WARPS_M_ };
+ enum { WARPS_N = WARPS_N_ };
+ enum { COLS = HIDDEN_SIZE_ };
+ enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
+ enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
+ enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
+
+ enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
+ enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
+ enum { ROWS_PER_CTA = WARPS_M };
+
+ enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
+ enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
+ // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
+ enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
+ static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
+
+ using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
+ using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
+
+ enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
+ enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
+
+ using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
+ using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
+ using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
+ using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
+ using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
+ using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
+ enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
+
+ // Assume that each thread can handle the same number of elements in the output and weights as in the input.
+ static_assert(sizeof(input_t) == sizeof(output_t));
+ static_assert(sizeof(input_t) <= sizeof(residual_t));
+ // The number of columns fetched per load from input: one per thread.
+ enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
+ // The total number of vectorized loads/stores per hidden vector.
+ enum { VEC_COLS = COLS / ELTS_PER_LDG };
+ // The number of loads per thread for the input.
+ enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
+ static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
+ //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
+
+ using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
+ enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
+
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // namespace layer_norm
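To make the `Kernel_traits` arithmetic concrete, here is a worked example for one registered configuration: hidden size 1024, fp16 input, `CTAS_PER_ROW=1`, `WARPS_M=4`, `WARPS_N=1`, `BYTES_PER_LDG=16` (the fp16/fp16/fp16/fp16/fp32 line in ln_fwd_1024.cu). This is a standalone constexpr restatement, independent of the header, just to show the numbers.

```cpp
// Worked example of the Kernel_traits derived constants for the 1024-wide fp16 config.
#include <cstdint>

constexpr uint32_t HIDDEN_SIZE = 1024;
constexpr uint32_t CTAS_PER_ROW = 1, WARPS_M = 4, WARPS_N = 1, BYTES_PER_LDG = 16;
constexpr uint32_t THREADS_PER_WARP = 32;
constexpr uint32_t SIZEOF_INPUT = 2;  // fp16

constexpr uint32_t NUM_ELTS = BYTES_PER_LDG / SIZEOF_INPUT;            // 8 halves per 16B load
constexpr uint32_t THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;       // 32 threads per row
constexpr uint32_t THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW;        // 128 threads, 4 rows/CTA
constexpr uint32_t VEC_COLS = HIDDEN_SIZE / NUM_ELTS;                  // 128 vectors per row
constexpr uint32_t VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW;  // 32 vectors per pass
constexpr uint32_t LDGS = VEC_COLS / VEC_COLS_PER_LDG;                 // 4 loads per thread

static_assert(NUM_ELTS == 8 && THREADS_PER_CTA == 128 && LDGS == 4, "traits arithmetic");
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS, "hidden size must divide evenly");
```

Each thread therefore covers 4 loads of 8 fp16 elements, i.e. 32 columns of the 1024-wide row, and the `Is_even_cols`/`num_valid_ldgs` machinery in the kernel handles runtime hidden sizes smaller than the compiled `HIDDEN_SIZE`.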
ln_parallel_bwd_1024.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
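The backward registrations carry one extra trailing argument, `BYTES_PER_LDG_FINAL` (4 here). My reading, based on `Kernel_traits_finalize` in ln_kernel_traits.h above, is that it sets the load width for the separate dgamma/dbeta finalize pass, which reduces fp32 partial sums across CTAs. The snippet below just restates that header's arithmetic for this value; the dtype choices are an assumed example, not a statement about which specialization runs.

```cpp
// Constexpr restatement of the Kernel_traits_finalize load/store widths for
// BYTES_PER_LDG_FINAL = 4, assuming fp32 partial sums and fp16 weight gradients.
#include <cstdint>

constexpr uint32_t BYTES_PER_LDG_FINAL = 4;
constexpr uint32_t SIZEOF_COMPUTE = 4;  // fp32 partial dgamma/dbeta sums
constexpr uint32_t SIZEOF_WEIGHT = 2;   // fp16 dgamma/dbeta output

constexpr uint32_t ELTS_PER_LDG = BYTES_PER_LDG_FINAL / SIZEOF_COMPUTE;  // one fp32 per load
constexpr uint32_t BYTES_PER_STG = ELTS_PER_LDG * SIZEOF_WEIGHT;         // 2-byte store per element

static_assert(ELTS_PER_LDG == 1 && BYTES_PER_STG == 2, "finalize traits arithmetic");
```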
ln_parallel_bwd_1280.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_parallel_bwd_1536.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_parallel_bwd_2048.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_parallel_bwd_256.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_parallel_bwd_2560.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_parallel_bwd_3072.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_parallel_bwd_4096.cu ADDED
@@ -0,0 +1,17 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ // Use 8 warps otherwise there's a lot of register spilling
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
ln_parallel_bwd_512.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_parallel_bwd_5120.cu ADDED
@@ -0,0 +1,17 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ // Use 8 warps otherwise there's a lot of register spilling
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
ln_parallel_bwd_6144.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
ln_parallel_bwd_7168.cu ADDED
@@ -0,0 +1,15 @@
+ #include "ln_parallel_residual_bwd_kernels.cuh"
+
+ // Create backward launch function and register. Macro signature:
+ // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+ REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);