Commit 89a2cd3
Parent(s): 5bfdae9

Add fused ApplyRoPE and RMSNorm kernels written in OpenAI Triton.

Files changed:
- config.json +1 -0
- configuration_qwen.py +2 -0
- modeling_qwen.py +35 -2
- triton_kernels.py +147 -0
config.json
CHANGED

@@ -44,6 +44,7 @@
   "use_cache": true,
   "use_dynamic_ntk": true,
   "use_flash_attn": "auto",
+  "use_triton": "auto",
   "use_logn_attn": true,
   "vocab_size": 151936
 }
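With the flag in config.json, Triton can also be toggled at load time. A minimal sketch of the usual remote-code loading flow (the model id and device placement are illustrative, not part of this commit):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat",     # illustrative checkpoint shipping this code
        trust_remote_code=True,  # the config/modeling code in this repo
        use_triton=True,         # overrides the "auto" default added here
    ).cuda()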
configuration_qwen.py
CHANGED

@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
         use_dynamic_ntk=True,
         use_logn_attn=True,
         use_flash_attn="auto",
+        use_triton="auto",
         intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
+        self.use_triton = use_triton
         self.no_bias = no_bias
         self.use_cache_quantization = use_cache_quantization
         self.use_cache_kernel = use_cache_kernel
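As with use_flash_attn, the new parameter accepts True, False, or "auto". A sketch, assuming a local checkout on sys.path:

    from configuration_qwen import QWenConfig

    config = QWenConfig()      # use_triton defaults to "auto"
    config.use_triton = False  # force the flash-attn / eager paths instead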
modeling_qwen.py
CHANGED

@@ -36,7 +36,7 @@ except ImportError:
 from torch import nn
 
 SUPPORT_CUDA = torch.cuda.is_available()
-SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
+SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
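This change gates bf16 on compute capability 8.0 (Ampere) or newer rather than torch.cuda.is_bf16_supported(), presumably because the latter can report True on pre-Ampere cards where bf16 is merely emulated. The standalone equivalent, for reference:

    import torch

    major, minor = torch.cuda.get_device_capability(0)
    print(f"sm_{major}{minor}: native bf16 {'yes' if major >= 8 else 'no'}")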
@@ -77,6 +77,7 @@ We detect you have activated flash attention support, but running model computat
 """
 
 apply_rotary_emb_func = None
+apply_rotary_emb_func_triton = None
 rms_norm = None
 flash_attn_unpadded_func = None
 
@@ -116,6 +117,30 @@ def _import_flash_attn():
             "https://github.com/Dao-AILab/flash-attention"
         )
 
+def _import_triton():
+    global apply_rotary_emb_func_triton, rms_norm
+    try:
+        from .triton_kernels import triton, apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
+        if apply_rotary_emb_func is not None:
+            logger.warn(
+                "rotary kernel imported from flash_attn is replaced by Triton kernel."
+            )
+        apply_rotary_emb_func_triton = __apply_rotary_emb
+        if rms_norm is not None:
+            logger.warn(
+                "rms_norm kernel imported from flash_attn is replaced by Triton kernel."
+            )
+        rms_norm = __rms_norm
+    except ImportError:
+        logger.warn("Warning: Failed to import Triton kernels.")
+        return
+
+    if int(triton.__version__.split(".")[1]) == 0:
+        logger.warn(
+            "Triton 2.0 is detected in your environment. It is recommended that you upgrade to Triton 2.1 by "
+            "`pip install triton==2.1` for better performance if you do not use TorchInductor in PyTorch 2.0."
+        )
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
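Once a model is loaded, one way to confirm which fused paths actually took effect is to inspect these module-level globals on the dynamically loaded module; a sketch, reusing `model` from the loading example above:

    import sys

    mod = sys.modules[type(model).__module__]  # the remote-code modeling_qwen
    print("Triton RoPE:", mod.apply_rotary_emb_func_triton is not None)
    print("fused rms_norm:", mod.rms_norm is not None)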
@@ -1052,6 +1077,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if config.use_flash_attn:
             _import_flash_attn()
 
+        if config.use_triton == "auto":
+            config.use_triton = SUPPORT_TORCH2
+        if config.use_triton:
+            logger.warn("Try importing Triton kernels for faster inference...")
+            _import_triton()
+
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
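Here "auto" resolves once at model construction time: it simply follows SUPPORT_TORCH2, since PyTorch 2.x typically ships a compatible Triton as a dependency of torch.compile. Spelled out:

    # use_triton="auto" -> enabled iff PyTorch >= 2 (SUPPORT_TORCH2)
    # use_triton=True   -> always attempt _import_triton()
    # use_triton=False  -> leave the flash-attn / eager paths untouched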
@@ -1412,7 +1443,9 @@ def _rotate_half(x):
 
 def apply_rotary_pos_emb(t, freqs):
     cos, sin = freqs
-    if apply_rotary_emb_func is not None and t.is_cuda:
+    if apply_rotary_emb_func_triton is not None and t.is_cuda:
+        return apply_rotary_emb_func_triton(t, cos, sin)
+    elif apply_rotary_emb_func is not None and t.is_cuda:
         t_ = t.float()
         cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
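Both branches of apply_rotary_pos_emb compute the same rotation. As a plain-PyTorch reference sketch (t is (batch, seq, heads, head_dim); cos/sin carry the duplicated half-dim frequencies, broadcast as (1, seq, 1, head_dim)):

    import torch

    def rotate_half_ref(x):
        # same permutation as _rotate_half above: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_ref(t, cos, sin):
        t_ = t.float()
        return (t_ * cos + rotate_half_ref(t_) * sin).type_as(t)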
triton_kernels.py
ADDED
@@ -0,0 +1,147 @@
+from functools import partial
+from typing import Any, Callable, Dict, Hashable, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from torch.autograd.function import Function, FunctionCtx
+from triton.compiler import CompiledKernel
+from triton.runtime import KernelInterface
+from triton.runtime.autotuner import Autotuner
+
+try:
+    import triton.language.math as tlmath  # Triton 2.1
+except ImportError:
+    import triton.language.libdevice as tlmath  # Triton 2.0
+
+
+class TritonKernel:
+    def __init__(
+        self,
+        kernel_fn_: KernelInterface,
+        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
+    ) -> None:
+        self.kernel_fn_ = kernel_fn_
+        self.grid_fn_ = grid_fn
+        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
+
+    def run(self, *args, **kwargs):
+        # Set current device
+        input_device = args[0].device
+        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
+        if input_device.index != cur_dev_idx:
+            prev_dev_idx = cur_dev_idx
+            torch.cuda.set_device(input_device.index)
+
+        # Compute grid
+        grid = self.grid_fn_(args)
+
+        # Use cached kernel if possible
+        kernel_key = (input_device,)
+        if isinstance(self.kernel_fn_, Autotuner):
+            kernel_key += tuple(args[ki] for ki in self.kernel_fn_.key_idx)
+        else:
+            kernel_key += tuple(kwargs.items())
+        if kernel_key in self.kernel_cache_:
+            kernel = self.kernel_cache_[kernel_key]
+            kernel[grid](*args)
+            return
+
+        # Compile new kernel
+        if isinstance(self.kernel_fn_, Autotuner):
+            kernel = self.kernel_fn_[grid](*args)
+        else:
+            kernel = self.kernel_fn_[grid](*args, **kwargs)
+
+        # Store kernel
+        self.kernel_cache_[kernel_key] = kernel
+
+        # Restore previous device
+        torch.cuda.set_device(prev_dev_idx)
+
+
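TritonKernel exists to skip Triton's Python-side dispatch overhead on steady-state calls: compiled binaries are cached per device (plus autotuner key values, or the constexpr kwargs) and relaunched directly. Note that the early return on a cache hit skips the device restore at the end, so the previous device is only restored on the compile path; torch.cuda.set_device(-1) is a no-op, so the default prev_dev_idx is harmless. A hypothetical usage sketch (not part of this commit; input size must be a multiple of BLOCK):

    @triton.jit
    def _scale_kernel(X, Y, scale, BLOCK: tl.constexpr):
        # toy elementwise kernel, used only to illustrate the wrapper
        idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        tl.store(Y + idx, tl.load(X + idx) * scale)

    scale_kernel = TritonKernel(
        _scale_kernel, lambda args: (args[0].numel() // 1024, 1, 1)
    )
    # x = torch.randn(4096, device="cuda"); y = torch.empty_like(x)
    # scale_kernel.run(x, y, 2.0, BLOCK=1024)  # second call hits the cache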
+@triton.jit
+def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
+    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
+    block_idx = tl.arange(0, HEAD_DIM)
+    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
+    x = tl.load(X + x_base_idx + block_idx)
+    freq_idx = tok_idx * HEAD_DIM + block_idx
+    cos = tl.load(Cos + freq_idx)
+    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
+    x_rot = tl.load(X + x_base_idx + rot_idx)
+    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
+    sin = tl.load(Sin + freq_idx)
+    y_idx = (
+        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
+    ) * HEAD_DIM + block_idx
+    y = x * cos + x_rot * sin
+    tl.store(Y + y_idx, y)
+
+
+apply_rope_fwd_kernel = TritonKernel(
+    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
+)
+
+
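The indexing encodes two layout assumptions. X is read with a num_heads * 3 stride, which matches a head-split q (or k) view into the fused [q|k|v] c_attn projection output in modeling_qwen.py, while Y is written densely as (batch, seq, heads, head_dim); Cos and Sin are read as contiguous (seq, head_dim) buffers with duplicated halves. The rot_idx / tl.where pair reproduces rotate_half without a second pass. In offset form:

    # x[b, t, h, d] -> ((b*seq_len + t) * num_heads*3 + h) * HEAD_DIM + d  (strided view)
    # y[b, t, h, d] -> ((b*seq_len + t) * num_heads   + h) * HEAD_DIM + d  (dense output)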
+class ApplyRotaryEmb(Function):
+    @staticmethod
+    def forward(
+        ctx: FunctionCtx, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ):
+        y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
+        return y
+
+
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+    return ApplyRotaryEmb.apply(x, cos, sin)
+
+
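A parity check against the eager rotation; a sketch that builds a packed [q|k|v] buffer so q has the strided layout the kernel assumes (all shapes hypothetical):

    import torch

    b, s, nh, hd = 2, 16, 4, 64
    qkv = torch.randn(b, s, 3 * nh * hd, device="cuda", dtype=torch.float16)
    q = qkv[..., : nh * hd].view(b, s, nh, hd)     # strided view into [q|k|v]

    pos = torch.arange(s, device="cuda", dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, hd, 2, device="cuda").float() / hd))
    emb = torch.outer(pos, inv_freq).repeat(1, 2)  # (s, hd), duplicated halves
    cos, sin = emb.cos(), emb.sin()

    y = apply_rotary_emb(q, cos, sin)              # Triton path
    q32 = q.float()
    rot = torch.cat((-q32[..., hd // 2:], q32[..., : hd // 2]), dim=-1)
    ref = q32 * cos[None, :, None, :] + rot * sin[None, :, None, :]
    torch.testing.assert_close(y.float(), ref, atol=1e-2, rtol=1e-2)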
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": 4096}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 1024}),
+    ],
+    key=[],
+)
+@triton.jit
+def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
+    tok_idx = tl.program_id(0)
+
+    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(
+            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
+        ).to(tl.float32)
+        mean_sq += x * x / hidden_dim
+    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
+
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        dim_mask = dim_idx < hidden_dim
+        hidden_idx = tok_idx * hidden_dim + dim_idx
+        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
+        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
+        y = x * rrms * w
+        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
+
+
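This is the standard two-pass RMSNorm over the hidden dimension, accumulated in fp32: the first loop computes rrms = 1 / sqrt(mean(x * x) + eps), the second applies y = x * rrms * w, with masking covering hidden sizes that are not multiples of BLOCK_SIZE. Since the autotuner keys on nothing (key=[]), one BLOCK_SIZE is picked per device and then reused through TritonKernel's cache.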
+rms_norm_fwd_kernel = TritonKernel(
+    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
+)
+
+
+class RMSNorm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: FunctionCtx, x: torch.Tensor, w: torch.Tensor, eps: float):
+        y = torch.empty_like(x)
+        rms_norm_fwd_kernel.run(x, w, y, eps, x.size(-1))
+        return y
+
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
+    return RMSNorm.apply(x, weight, eps)
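Both autograd Functions implement only forward, so these kernels are inference-only. A closing parity sketch for the exported rms_norm against an eager reference (shapes hypothetical):

    import torch

    x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
    w = torch.randn(4096, device="cuda", dtype=torch.float16)
    eps = 1e-6

    y = rms_norm(x, w, eps)  # Triton path defined above

    x32 = x.float()
    ref = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps) * w.float()
    torch.testing.assert_close(y, ref.to(torch.float16), atol=1e-2, rtol=1e-2)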