wangzihan99 committed
Commit 89a2cd3 · 1 parent: 5bfdae9

Add fused ApplyRoPE and RMSNorm kernels written in OpenAI Triton.

Files changed (4):
  1. config.json +1 -0
  2. configuration_qwen.py +2 -0
  3. modeling_qwen.py +35 -2
  4. triton_kernels.py +147 -0
config.json CHANGED
@@ -44,6 +44,7 @@
   "use_cache": true,
   "use_dynamic_ntk": true,
   "use_flash_attn": "auto",
+  "use_triton": "auto",
   "use_logn_attn": true,
   "vocab_size": 151936
 }
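
The new `use_triton` key mirrors `use_flash_attn`: "auto" defers the decision to load time (resolved in modeling_qwen.py below), while an explicit boolean forces one path. A minimal sketch of overriding it at load time, assuming the usual trust_remote_code flow in which extra from_pretrained kwargs override config attributes; the checkpoint name is a placeholder:

    from transformers import AutoModelForCausalLM

    # use_triton=False overrides the "auto" written in config.json above.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat",     # placeholder checkpoint name
        trust_remote_code=True,  # required for this repo's custom modeling code
        use_triton=False,
    ).eval()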
configuration_qwen.py CHANGED
@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
         use_dynamic_ntk=True,
         use_logn_attn=True,
         use_flash_attn="auto",
+        use_triton="auto",
         intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
+        self.use_triton = use_triton
         self.no_bias = no_bias
         self.use_cache_quantization = use_cache_quantization
         self.use_cache_kernel = use_cache_kernel
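
The flag is likewise available when building the config programmatically. A small sketch, assuming the repo root is importable:

    from configuration_qwen import QWenConfig

    # "auto" (the default) is resolved to a boolean in QWenLMHeadModel.__init__
    # via a PyTorch >= 2 check; pass True or False to bypass that check.
    config = QWenConfig(use_triton="auto")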
modeling_qwen.py CHANGED
@@ -36,7 +36,7 @@ except ImportError:
 from torch import nn
 
 SUPPORT_CUDA = torch.cuda.is_available()
-SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
+SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
@@ -77,6 +77,7 @@ We detect you have activated flash attention support, but running model computat
 """
 
 apply_rotary_emb_func = None
+apply_rotary_emb_func_triton = None
 rms_norm = None
 flash_attn_unpadded_func = None
 
@@ -116,6 +117,30 @@ def _import_flash_attn():
             "https://github.com/Dao-AILab/flash-attention"
         )
 
+def _import_triton():
+    global apply_rotary_emb_func_triton, rms_norm
+    try:
+        from .triton_kernels import triton, apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
+        if apply_rotary_emb_func is not None:
+            logger.warn(
+                "rotary kernel imported from flash_attn is replaced by Triton kernel."
+            )
+        apply_rotary_emb_func_triton = __apply_rotary_emb
+        if rms_norm is not None:
+            logger.warn(
+                "rms_norm kernel imported from flash_attn is replaced by Triton kernel."
+            )
+        rms_norm = __rms_norm
+    except ImportError:
+        logger.warn("Warning: Failed to import Triton kernels.")
+        return
+
+    if int(triton.__version__.split(".")[1]) == 0:
+        logger.warn(
+            "Triton 2.0 is detected in your environment. It is recommended that you upgrade to Triton 2.1 by "
+            "`pip install triton==2.1` for better performance if you do not use TorchInductor in PyTorch 2.0."
+        )
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
 
@@ -1052,6 +1077,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if config.use_flash_attn:
             _import_flash_attn()
 
+        if config.use_triton == "auto":
+            config.use_triton = SUPPORT_TORCH2
+        if config.use_triton:
+            logger.warn("Try importing Triton kernels for faster inference...")
+            _import_triton()
+
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -1412,7 +1443,9 @@ def _rotate_half(x):
 
 def apply_rotary_pos_emb(t, freqs):
     cos, sin = freqs
-    if apply_rotary_emb_func is not None and t.is_cuda:
+    if apply_rotary_emb_func_triton is not None and t.is_cuda:
+        return apply_rotary_emb_func_triton(t, cos, sin)
+    elif apply_rotary_emb_func is not None and t.is_cuda:
         t_ = t.float()
         cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
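
After this change apply_rotary_pos_emb dispatches in priority order: the Triton kernel, then the flash-attn rotary kernel, then the pure-PyTorch fallback. For reference, an eager sketch of what the Triton path computes (rope_reference is a hypothetical helper; it assumes a contiguous input and cos/sin laid out as (seq_len, head_dim) with the frequency halves duplicated, which is how the kernel in triton_kernels.py indexes them):

    import torch

    def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, num_heads, head_dim); cos/sin: (seq_len, head_dim)
        half = x.size(-1) // 2
        # Same rotation the kernel builds with rot_idx and tl.where: (-x2, x1)
        x_rot = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return x * cos[:, None, :] + x_rot * sin[:, None, :]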
triton_kernels.py ADDED
@@ -0,0 +1,147 @@
+from functools import partial
+from typing import Any, Callable, Dict, Hashable, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from torch.autograd.function import Function, FunctionCtx
+from triton.compiler import CompiledKernel
+from triton.runtime import KernelInterface
+from triton.runtime.autotuner import Autotuner
+
+try:
+    import triton.language.math as tlmath  # Triton 2.1
+except ImportError:
+    import triton.language.libdevice as tlmath  # Triton 2.0
+
+
+class TritonKernel:
+    def __init__(
+        self,
+        kernel_fn_: KernelInterface,
+        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
+    ) -> None:
+        self.kernel_fn_ = kernel_fn_
+        self.grid_fn_ = grid_fn
+        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
+
+    def run(self, *args, **kwargs):
+        # Set current device
+        input_device = args[0].device
+        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
+        if input_device.index != cur_dev_idx:
+            prev_dev_idx = cur_dev_idx
+            torch.cuda.set_device(input_device.index)
+
+        # Compute grid
+        grid = self.grid_fn_(args)
+
+        # Use cached kernel if possible
+        kernel_key = (input_device,)
+        if isinstance(self.kernel_fn_, Autotuner):
+            kernel_key += tuple(args[ki] for ki in self.kernel_fn_.key_idx)
+        else:
+            kernel_key += tuple(kwargs.items())
+        if kernel_key in self.kernel_cache_:
+            kernel = self.kernel_cache_[kernel_key]
+            kernel[grid](*args)
+            return
+
+        # Compile new kernel
+        if isinstance(self.kernel_fn_, Autotuner):
+            kernel = self.kernel_fn_[grid](*args)
+        else:
+            kernel = self.kernel_fn_[grid](*args, **kwargs)
+
+        # Store kernel
+        self.kernel_cache_[kernel_key] = kernel
+
+        # Restore previous device
+        torch.cuda.set_device(prev_dev_idx)
+
+
+@triton.jit
+def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
+    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
+    block_idx = tl.arange(0, HEAD_DIM)
+    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
+    x = tl.load(X + x_base_idx + block_idx)
+    freq_idx = tok_idx * HEAD_DIM + block_idx
+    cos = tl.load(Cos + freq_idx)
+    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
+    x_rot = tl.load(X + x_base_idx + rot_idx)
+    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
+    sin = tl.load(Sin + freq_idx)
+    y_idx = (
+        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
+    ) * HEAD_DIM + block_idx
+    y = x * cos + x_rot * sin
+    tl.store(Y + y_idx, y)
+
+
+apply_rope_fwd_kernel = TritonKernel(
+    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
+)
+
+
+class ApplyRotaryEmb(Function):
+    @staticmethod
+    def forward(
+        ctx: FunctionCtx, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ):
+        y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
+        return y
+
+
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+    return ApplyRotaryEmb.apply(x, cos, sin)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": 4096}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 1024}),
+    ],
+    key=[],
+)
+@triton.jit
+def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
+    tok_idx = tl.program_id(0)
+
+    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(
+            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
+        ).to(tl.float32)
+        mean_sq += x * x / hidden_dim
+    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
+
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        dim_mask = dim_idx < hidden_dim
+        hidden_idx = tok_idx * hidden_dim + dim_idx
+        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
+        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
+        y = x * rrms * w
+        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
+
+
+rms_norm_fwd_kernel = TritonKernel(
+    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
+)
+
+
+class RMSNorm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: FunctionCtx, x: torch.Tensor, w: torch.Tensor, eps: float):
+        y = torch.empty_like(x)
+        rms_norm_fwd_kernel.run(x, w, y, eps, x.size(-1))
+        return y
+
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
+    return RMSNorm.apply(x, weight, eps)
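
Two properties of this file are easy to miss. Only forward is defined on both autograd Functions, so the kernels are inference-only; and _apply_rope_fwd_kernel computes x_base_idx with a num_heads * 3 stride, i.e. it appears to assume its input is a view into the packed QKV projection rather than an ordinary contiguous tensor. A minimal numerical check for the RMSNorm path against an eager reference (hypothetical test script; assumes a CUDA device and this file importable as triton_kernels):

    import torch
    from triton_kernels import rms_norm  # Triton wrapper defined above

    def rms_norm_reference(x, w, eps):
        # Mirror the kernel: mean of squares accumulated in fp32, then weight.
        x32 = x.float()
        rrms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x32 * rrms * w.float()).to(x.dtype)

    x = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda")
    w = torch.randn(4096, dtype=torch.float16, device="cuda")
    print((rms_norm(x, w, 1e-6) - rms_norm_reference(x, w, 1e-6)).abs().max())
    # expect fp16-level error (on the order of 1e-3)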