wangzihan99 committed
Commit 5b354c8 · 1 Parent(s): 74a1327

Improve performance with Triton 2.0 and adapt to latest Qwen releases.

Files changed (2)
  1. modeling_qwen.py +13 -14
  2. triton_kernels.py +16 -48
modeling_qwen.py CHANGED
@@ -78,6 +78,7 @@ We detect you have activated flash attention support, but running model computat
 apply_rotary_emb_func = None
 apply_rotary_emb_func_triton = None
 rms_norm = None
+rms_norm_triton = None
 flash_attn_unpadded_func = None
 flash_attn_func = None
 
@@ -122,28 +123,22 @@ def _import_flash_attn():
         )
 
 def _import_triton():
-    global apply_rotary_emb_func_triton, rms_norm
+    global apply_rotary_emb_func_triton, rms_norm_triton
     try:
-        from .triton_kernels import triton, apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
+        from .triton_kernels import apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
         if apply_rotary_emb_func is not None:
             logger.warn(
-                "rotary kernel imported from flash_attn is replaced by Triton kernel."
+                "Using Triton rotary kernel instead of flash_attn for inference."
             )
             apply_rotary_emb_func_triton = __apply_rotary_emb
         if rms_norm is not None:
             logger.warn(
-                "rms_norm kernel imported from flash_attn is replaced by Triton kernel."
+                "Using Triton rms_norm kernel instead of flash_attn for inference."
             )
-            rms_norm = __rms_norm
+            rms_norm_triton = __rms_norm
     except ImportError:
         logger.warn("Warning: Failed to import Triton kernels.")
         return
-
-    if int(triton.__version__.split(".")[1]) == 0:
-        logger.warn(
-            "Triton 2.0 is detected in your environment. It is recommended that you upgrade to Triton 2.1 by "
-            "`pip install triton==2.1` for better performance if you do not use TorchInductor in PyTorch 2.0."
-        )
 
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
@@ -1004,9 +999,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         _import_flash_attn()
 
         if config.use_triton == "auto":
+            logger.warn("Try importing Triton kernels for faster inference...")
             config.use_triton = SUPPORT_TORCH2
         if config.use_triton:
-            logger.warn("Try importing Triton kernels for faster inference...")
             _import_triton()
 
         self.transformer = QWenModel(config)
@@ -1366,7 +1361,9 @@ def apply_rotary_pos_emb(t, freqs):
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
-    if apply_rotary_emb_func is not None and t.is_cuda:
+    if apply_rotary_emb_func_triton is not None and t.is_cuda and (not t.requires_grad):
+        return apply_rotary_emb_func_triton(t, cos, sin)
+    elif apply_rotary_emb_func is not None and t.is_cuda:
         # apply_rotary_emb in flash_attn requires cos/sin to be of
         # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
@@ -1389,7 +1386,9 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        if rms_norm is not None and x.is_cuda:
+        if rms_norm_triton is not None and x.is_cuda and (not x.requires_grad):
+            return rms_norm_triton(x, self.weight, self.eps)
+        elif rms_norm is not None and x.is_cuda:
             return rms_norm(x, self.weight, self.eps)
         else:
             output = self._norm(x.float()).type_as(x)
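
One detail in the modeling changes worth spelling out: the Triton kernels in this commit are forward-only (the autograd wrappers are dropped in triton_kernels.py below), so the new branches fire only when no gradient is needed, i.e. a CUDA tensor with requires_grad == False. Training and CPU inputs still take the flash_attn or pure-PyTorch paths. A minimal standalone sketch of this dispatch pattern (names like rms_norm_dispatch are illustrative, not from the commit):

import torch

rms_norm_triton = None  # set by _import_triton() in the real code; stays None if the import fails

def _rms_norm_torch(x, weight, eps):
    # Pure-PyTorch fallback, mirroring RMSNorm._norm in modeling_qwen.py.
    output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return output.type_as(x) * weight

def rms_norm_dispatch(x, weight, eps=1e-6):
    # The Triton kernel has no backward, so it only serves inference:
    # a CUDA tensor that is not part of an autograd graph.
    if rms_norm_triton is not None and x.is_cuda and not x.requires_grad:
        return rms_norm_triton(x, weight, eps)
    return _rms_norm_torch(x, weight, eps)

x = torch.randn(2, 8, 128)
w = torch.ones(128)
with torch.no_grad():  # typical inference context; requires_grad is False here
    y = rms_norm_dispatch(x, w)
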
triton_kernels.py CHANGED
@@ -1,13 +1,10 @@
-from functools import partial
 from typing import Any, Callable, Dict, Hashable, Tuple
 
 import torch
 import triton
 import triton.language as tl
-from torch.autograd.function import Function, FunctionCtx
 from triton.compiler import CompiledKernel
-from triton.runtime import KernelInterface
-from triton.runtime.autotuner import Autotuner
+from triton.runtime import JITFunction
 
 try:
     import triton.language.math as tlmath  # Triton 2.1
@@ -18,10 +15,10 @@ except ImportError:
 class TritonKernel:
     def __init__(
         self,
-        kernel_fn_: KernelInterface,
+        kernel_fn: JITFunction,
         grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
     ) -> None:
-        self.kernel_fn_ = kernel_fn_
+        self.kernel_fn_ = kernel_fn
         self.grid_fn_ = grid_fn
         self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
 
@@ -37,24 +34,14 @@ class TritonKernel:
         grid = self.grid_fn_(args)
 
         # Use cached kernel if possible
-        kernel_key = (input_device,)
-        if isinstance(self.kernel_fn_, Autotuner):
-            kernel_key += tuple(args[ki] for ki in self.kernel_fn_.key_idx)
-        else:
-            kernel_key += tuple(kwargs.items())
+        kernel_key = (input_device,) + tuple(kwargs.items())
         if kernel_key in self.kernel_cache_:
             kernel = self.kernel_cache_[kernel_key]
             kernel[grid](*args)
-            return
-
-        # Compile new kernel
-        if isinstance(self.kernel_fn_, Autotuner):
-            kernel = self.kernel_fn_[grid](*args)
         else:
+            # Compile and store new kernel
             kernel = self.kernel_fn_[grid](*args, **kwargs)
-
-        # Store kernel
-        self.kernel_cache_[kernel_key] = kernel
+            self.kernel_cache_[kernel_key] = kernel
 
         # Restore previous device
         torch.cuda.set_device(prev_dev_idx)
@@ -77,7 +64,7 @@ def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
         (batch_idx * seq_len + tok_idx) * num_heads + head_idx
     ) * HEAD_DIM + block_idx
     y = x * cos + x_rot * sin
-    tl.store(Y + y_idx, y)
+    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))
 
 
 apply_rope_fwd_kernel = TritonKernel(
@@ -85,28 +72,12 @@ apply_rope_fwd_kernel = TritonKernel(
 )
 
 
-class ApplyRotaryEmb(Function):
-    @staticmethod
-    def forward(
-        ctx: FunctionCtx, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ):
-        y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
-        apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
-        return y
-
-
 def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-    return ApplyRotaryEmb.apply(x, cos, sin)
+    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
+    return y
 
 
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_SIZE": 4096}),
-        triton.Config({"BLOCK_SIZE": 2048}),
-        triton.Config({"BLOCK_SIZE": 1024}),
-    ],
-    key=[],
-)
 @triton.jit
 def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
     tok_idx = tl.program_id(0)
@@ -135,13 +106,10 @@ rms_norm_fwd_kernel = TritonKernel(
 )
 
 
-class RMSNorm(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx: FunctionCtx, x: torch.Tensor, w: torch.Tensor, eps: float):
-        y = torch.empty_like(x)
-        rms_norm_fwd_kernel.run(x, w, y, eps, x.size(-1))
-        return y
-
-
 def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
-    return RMSNorm.apply(x, weight, eps)
+    y = torch.empty_like(x)
+    hidden_dim = x.size(-1)
+    rms_norm_fwd_kernel.run(
+        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
+    )
+    return y
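
With the autograd wrappers and the autotuner removed, apply_rotary_emb and rms_norm are plain functions that allocate the output and launch the cached kernel; BLOCK_SIZE is now computed as triton.next_power_of_2(hidden_dim) instead of being autotuned and, since it is passed as a kwarg, it also becomes part of the kernel-cache key. A quick numerical sanity check of rms_norm against the pure-PyTorch reference in modeling_qwen.py might look like the following sketch (shapes are illustrative; assumes a CUDA device and this file importable as triton_kernels):

import torch
from triton_kernels import rms_norm  # this file; the kernels require a CUDA device

x = torch.randn(2, 16, 2048, device="cuda", dtype=torch.float16)
w = torch.ones(2048, device="cuda", dtype=torch.float16)
eps = 1e-6

# Reference computation: RMSNorm._norm(x.float()).type_as(x) * weight.
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).type_as(x) * w

out = rms_norm(x, w, eps)
# Loose tolerances: the fp16 kernel and the reference round at different points.
print(torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))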