Upload custom kernels
Browse files- build/torch-universal/unsloth_kernels/__init__.py +23 -0
- build/torch-universal/unsloth_kernels/_ops.py +8 -0
- build/torch-universal/unsloth_kernels/cross_entropy_loss.py +420 -0
- build/torch-universal/unsloth_kernels/fast_lora.py +537 -0
- build/torch-universal/unsloth_kernels/flex_attention.py +181 -0
- build/torch-universal/unsloth_kernels/geglu.py +213 -0
- build/torch-universal/unsloth_kernels/layernorm.py +170 -0
- build/torch-universal/unsloth_kernels/rms_layernorm.py +261 -0
- build/torch-universal/unsloth_kernels/rope_embedding.py +202 -0
- build/torch-universal/unsloth_kernels/swiglu.py +101 -0
- build/torch-universal/unsloth_kernels/utils.py +497 -0
- flake.lock +117 -0
- flake.nix +2 -2
build/torch-universal/unsloth_kernels/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public re-exports for the unsloth_kernels package.
# Fix: the original imported `swiglu_fg_kernel` from .swiglu twice; the
# duplicate import is removed. The exported names are unchanged.
from .cross_entropy_loss import fast_cross_entropy_loss
from .fast_lora import fast_lora_forward
from .flex_attention import slow_inference_attention_softcapping
from .layernorm import fast_layernorm
from .rope_embedding import inplace_rope_embedding, fast_rope_embedding
from .rms_layernorm import fast_rms_layernorm
from .swiglu import swiglu_fg_kernel
from .geglu import (
    geglu_approx_forward_kernel,
    geglu_approx_backward_kernel,
    geglu_exact_forward_kernel,
    geglu_exact_backward_kernel,
)

__all__ = [
    "fast_cross_entropy_loss",
    "fast_lora_forward",
    "slow_inference_attention_softcapping",
    "fast_layernorm",
    "inplace_rope_embedding",
    "fast_rms_layernorm",
    "swiglu_fg_kernel",
    "geglu_approx_forward_kernel",
    "geglu_approx_backward_kernel",
    "geglu_exact_forward_kernel",
    "geglu_exact_backward_kernel",
    "fast_rope_embedding",
]
|
build/torch-universal/unsloth_kernels/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

# Handle to this wheel's custom-op namespace as registered with torch.ops.
# Attribute access on torch.ops is lazy, so this line never fails at import.
ops = torch.ops._unsloth_kernels_a210373_dirty

_NAMESPACE = "_unsloth_kernels_a210373_dirty"


def add_op_namespace_prefix(op_name: str):
    """
    Return *op_name* qualified with this package's torch.ops namespace,
    e.g. ``"foo"`` -> ``"_unsloth_kernels_a210373_dirty::foo"``.
    """
    return _NAMESPACE + "::" + op_name
|
build/torch-universal/unsloth_kernels/cross_entropy_loss.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
import torch
|
| 18 |
+
from .utils import (
|
| 19 |
+
calculate_settings,
|
| 20 |
+
MAX_FUSED_SIZE,
|
| 21 |
+
triton_tanh,
|
| 22 |
+
triton_cast,
|
| 23 |
+
torch_cuda_device,
|
| 24 |
+
)
|
| 25 |
+
from transformers.models.llama.modeling_llama import logger
|
| 26 |
+
from packaging.version import Version
|
| 27 |
+
|
| 28 |
+
from unsloth_zoo.loss_utils import (
|
| 29 |
+
patch_loss_functions as _patch_loss_functions,
|
| 30 |
+
post_patch_loss_function,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _cross_entropy_forward(
    logits_ptr,
    logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_SOFTCAPPING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    Single-block cross entropy forward: one Triton program per row of logits,
    with the whole vocabulary covered by one BLOCK_SIZE-wide block
    (BLOCK_SIZE >= VOCAB_SIZE — the launcher picks it via calculate_settings).

    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    logsumexp is also stable
    Take y = log[sum(exp(x))]
       exp(y) = sum(exp(x))
       exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
       exp(y) = exp(c)*sum(exp(x - c))
       y = log(exp(c)*sum(exp(x - c)))
       y = c + log[sum(exp(x - c))]
    This means we can set c = max(x) to make sure
    exp(x - c) always is exp(x - max(x)).
    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.

    Writes the per-row loss to loss_ptr[row] and the row's logsumexp to
    logsumexp_ptr[row] (the latter is reused by the backward kernel).
    Labels equal to -100 are treated as padding and produce loss 0.
    """
    row_idx = tl.program_id(0)
    # Cast the row stride to int64 so large (rows * vocab) offsets cannot
    # overflow 32-bit pointer arithmetic.
    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
    loss_ptr += row_idx
    logsumexp_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    # Out-of-vocab lanes load -inf so they contribute exp(-inf) = 0 below.
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

    # Numerically stable logsumexp: subtract the row maximum first.
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label_idx != -100:
        # Re-load just the target logit and re-apply the same transforms,
        # since the transformed `logits` vector above is not indexable by label.
        x = tl.load(logits_ptr + label_idx).to(tl.float32)
        # Go logit scaling for Cohere: t * x
        if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
        if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
        loss = logsumexp - x
    else:
        # -100 is the ignore index: padding rows carry zero loss.
        loss = 0.0
    tl.store(logsumexp_ptr, logsumexp)
    tl.store(loss_ptr, loss)
pass
# JIT-compile, then attach heuristics that coerce the two flags to bool at
# launch time so 0/0.0 arguments become compile-time-constant False branches.
_cross_entropy_forward = triton.jit(_cross_entropy_forward)
_cross_entropy_forward = triton.heuristics(
    {
        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
        "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
    }
)(_cross_entropy_forward)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _chunked_cross_entropy_forward(
    logits_ptr,
    logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    N_CHUNKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_SOFTCAPPING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    Chunked cross entropy forward for vocabularies larger than one block:
    a 2-D grid of programs, (row, chunk), where each program reduces one
    BLOCK_SIZE slice of the row and stores a partial logsumexp. The host
    later combines the partials with torch.logsumexp (see the launcher).

    256K vocab divided in 4 chunks

    |-65536-| |-65536-| |-65536-| |-65536-|
    |-------| |-------| |-------| |-------|
    |-------| |-------| |-------| |-------|

    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    Notice we can do logsumexp for each chunk and then
    logsumexp[chunk_sum(logsumexp)] == logsumexp

    chunk_sum = log[chunk_sum(logsumexp)]
              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
              = log[sum(exp(a)) + ... + sum(exp(z))]
              = logsumexp(x)

    This means we can perform a logsumexp for each chunk, then do a
    final logsumexp reduction!

    Ie do: logsumexp(chunked_logsumexp) - x

    This kernel stores only the "-x" part into loss_ptr (done once, by
    chunk 0); the host adds the combined logsumexp afterwards.
    """
    row_idx = tl.program_id(0)
    chunk_idx = tl.program_id(1)
    # int64 cast keeps row offsets from overflowing on huge vocab tensors.
    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
    loss_ptr += row_idx
    # Partial logsumexps are laid out (n_rows, N_CHUNKS) row-major.
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr += row_idx

    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

    # Stable partial logsumexp over this chunk only.
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if chunk_idx == 0:
        # logsumexp(chunked_logsumexp) - x
        # Do the -x separately
        if label_idx != -100:
            x = tl.load(logits_ptr + label_idx).to(tl.float32)
            # Go logit scaling for Cohere: t * x
            if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
            # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
            if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
            loss = -1.0 * x
        else:
            loss = 0.0
        tl.store(loss_ptr, loss)
    pass
    tl.store(logsumexp_ptr, logsumexp)
pass
# JIT-compile, then coerce the feature flags to bool at launch time.
_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
_chunked_cross_entropy_forward = triton.heuristics(
    {
        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
        "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
    }
)(_chunked_cross_entropy_forward)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _cross_entropy_backward(
    logits_ptr,
    logits_row_stride,
    dloss_ptr,
    dloss_row_stride,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_SOFTCAPPING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    Cross entropy backward over a 2-D grid (row, vocab-block). NOTE: the
    gradient is written IN PLACE over the logits buffer (the autograd
    wrapper returns the mutated `logits` tensor as dX).

    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx = exp[x - logsumexp]
    """
    row_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
    dloss_ptr += row_idx * dloss_row_stride
    col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        # Ignored (-100) rows contribute zero gradient.
        dloss = 0.0

    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

    # Do logit scaling for Cohere
    if DO_LOGIT_SCALING:
        # d/dx [s * x] = s
        x = x * LOGIT_SCALE
    pass

    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    # `partial` keeps tanh(x/t) around for the chain-rule factor below.
    partial = x
    if DO_SOFTCAPPING:
        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
        partial = triton_tanh(x / SOFTCAP)
        x = SOFTCAP * partial
    pass

    logsumexp = tl.load(logsumexp_ptr + row_idx)
    # softmax via exp(x - logsumexp); subtract 1 only on the target column.
    y = tl.exp(x - logsumexp)
    y = tl.where(
        col_offsets == label_idx,
        y - 1.0, # exp(x - logsumexp) - 1
        y, # exp(x - logsumexp)
    )

    if DO_LOGIT_SCALING:
        # d/dx [s * x] = s
        y = y * LOGIT_SCALE
    pass

    if DO_SOFTCAPPING:
        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
        y = y * (1.0 - partial*partial)
    pass

    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
    # In-place overwrite of the logits with the gradient.
    tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
# JIT-compile, then coerce the feature flags to bool at launch time.
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
_cross_entropy_backward = triton.heuristics(
    {
        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
        "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
    }
)(_cross_entropy_backward)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# NOTE: this intentionally shadows MAX_FUSED_SIZE imported from .utils —
# the local value is the single-block vocab limit for the forward kernels.
MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
    """
    Fused Triton cross entropy over flattened (n_rows, vocab_size) logits.
    Forward returns per-row float32 losses; backward overwrites the saved
    logits buffer with the gradient and returns it (no extra allocation).
    Supports Gemma-2 style logit softcapping and Cohere style logit scaling
    (0 means disabled for both).
    """
    @staticmethod
    def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
        n_rows : int
        vocab_size : int
        n_rows, vocab_size = logits.shape
        device = logits.device

        # How many MAX_FUSED_SIZE-wide chunks are needed to span the vocab.
        div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
        n_chunks : int = div + (mod != 0)
        losses = torch.empty(n_rows, dtype = torch.float32, device = device)

        # 0 (or 0.0) disables the corresponding transform.
        DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
        DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)

        BLOCK_SIZE : int
        num_warps : int
        if n_chunks == 1:
            # For small vocabs <= 65336 like Llama, Mistral
            BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)

            with torch_cuda_device(device):
                _cross_entropy_forward[(n_rows,)](
                    logits, logits.stride(0),
                    losses,
                    logsumexp,
                    labels,
                    VOCAB_SIZE = vocab_size,
                    BLOCK_SIZE = BLOCK_SIZE,
                    DO_SOFTCAPPING = DO_SOFTCAPPING,
                    SOFTCAP = logit_softcapping,
                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,
                    LOGIT_SCALE = logit_scaling,
                    num_warps = num_warps,
                )
        else:
            # For large vocabs > 65336 like Gemma 256K
            # One partial logsumexp per (row, chunk); combined below.
            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)

            with torch_cuda_device(device):
                _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
                    logits, logits.stride(0),
                    losses,
                    logsumexp,
                    labels,
                    VOCAB_SIZE = vocab_size,
                    N_CHUNKS = n_chunks,
                    BLOCK_SIZE = MAX_FUSED_SIZE,
                    DO_SOFTCAPPING = DO_SOFTCAPPING,
                    SOFTCAP = logit_softcapping,
                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,
                    LOGIT_SCALE = logit_scaling,
                    num_warps = 32,
                )
            # logsumexp(chunked_logsumexp) - x
            # Do the -x separately
            # The kernel stored -x into `losses`; combine the chunk partials
            # and add them to complete logsumexp - x per row.
            logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
            losses += logsumexp
            losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
        pass

        ctx.save_for_backward(logits, logsumexp, labels)
        ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
        ctx.logit_softcapping = logit_softcapping
        ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
        ctx.logit_scaling = logit_scaling
        return losses
    pass


    @staticmethod
    def backward(ctx, dlosses):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows : int
        vocab_size : int
        n_rows, vocab_size = logits.shape

        BLOCK_SIZE : int = 4096
        div : int
        mod : int
        div, mod = divmod(vocab_size, BLOCK_SIZE)
        n_blocks : int = div + (mod != 0)

        with torch_cuda_device(dlosses.device):
            _cross_entropy_backward[(n_rows, n_blocks,)](
                logits, logits.stride(0),
                dlosses, dlosses.stride(0),
                logsumexp,
                labels,
                VOCAB_SIZE = vocab_size,
                BLOCK_SIZE = BLOCK_SIZE,
                DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
                SOFTCAP = ctx.logit_softcapping,
                DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
                LOGIT_SCALE = ctx.logit_scaling,
                num_warps = 8,
            )
        # The backward kernel wrote dX in place over `logits`; the remaining
        # Nones match labels / logit_softcapping / logit_scaling inputs.
        return logits, None, None, None,
    pass
pass
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def fast_cross_entropy_loss(
    logits,
    labels,
    logit_softcapping = 0,
    logit_scaling = 0,
    n_items = None,
):
    """
    Mean fused cross entropy loss over a batch of token logits.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        losses: float
    """
    batch, seq_len, vocab = logits.shape
    assert(labels.shape == (batch, seq_len))

    # Flatten (batch, seq) into rows for the fused kernel.
    token_losses = Fast_CrossEntropyLoss.apply(
        logits.view(batch * seq_len, vocab),
        labels.view(-1),
        logit_softcapping,
        logit_scaling,
    )
    # Default normaliser: number of non-padding (-100) labels.
    divisor = torch.count_nonzero(labels != -100) if n_items is None else n_items
    return token_losses.sum() / divisor
pass
if (Version(torch.__version__) < Version("2.4.0")) and \
    not hasattr(fast_cross_entropy_loss, "__wrapped__"):
    fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
pass
|
| 416 |
+
|
| 417 |
+
# Patch CE Losses in transformers
|
| 418 |
+
def patch_loss_functions(torch_compile = True):
    # Replace transformers' cross entropy implementations with the fused
    # Triton version via unsloth_zoo's patcher; `torch_compile` is forwarded
    # to control whether the patched loss is torch.compile'd.
    _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
pass
|
build/torch-universal/unsloth_kernels/fast_lora.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from .utils import (
|
| 17 |
+
fast_dequantize,
|
| 18 |
+
QUANT_STATE,
|
| 19 |
+
get_lora_parameters,
|
| 20 |
+
get_lora_parameters_bias,
|
| 21 |
+
matmul_lora,
|
| 22 |
+
torch_amp_custom_fwd,
|
| 23 |
+
torch_amp_custom_bwd,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LoRA_MLP(torch.autograd.Function):
    """
    Fused forward/backward for a gated MLP (gate/up/down projections) whose
    weights carry LoRA adapters, with optional quantized base weights
    (dequantized on the fly). The activation pair (`_forward_function`,
    `_backward_function`) is pluggable — SwiGLU or GEGLU kernels.

    ### LoRA weights
    G = G + Ag @ Bg
    U = U + Au @ Bu
    W = W + Aw @ Bw

    ### SwiGLU(X)
    e = X @ G
    f = e * sigmoid(e)
    g = X @ U
    h = f * g
    i = h @ W

    ### Backpropagation chain rule
    See our blog post for more details

    df = sigmoid(e) * (1 - f) + f
    dC/dW = h.T @ dY
    dC/dU = X.T @ (D @ W.T * f)
    dC/dG = X.T @ (D @ W.T * df * g)

    ### Down projection LoRA weights
    dC/dAw = dC/dW @ B.T
    dC/dBw = A.T @ dC/dW
    dC/dAw = h.T @ dY @ B.T
    dC/dBw = A.T @ h.T @ dY

    ### Up projection LoRA weights
    dC/dAu = X.T @ (D @ W.T * f) @ B.T
    dC/dBu = A.T @ X.T @ (D @ W.T * f)

    ### Gate projection LoRA weights
    dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
    dC/dBg = A.T @ X.T @ (D @ W.T * df * g)

    Don't forget to see our blog post for more details!
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, X : torch.Tensor,
                gateW, gateW_quant, gateA, gateB, gateS,
                upW, upW_quant, upA, upB, upS,
                downW, downW_quant, downA, downB, downS,
                _forward_function, _backward_function,
                inplace = True,):
        # inplace=True lets backward reuse X's storage for dX (saves memory).
        dtype = X.dtype

        # e = X @ (gateW + gateA @ gateB * gateS), etc. via matmul_lora.
        e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
        g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
        h = _forward_function(e, g)
        i = matmul_lora(h, downW, downW_quant, downA, downB, downS)

        # Non-tensor / frozen objects go on ctx directly; only the LoRA
        # adapters and activations go through save_for_backward.
        ctx.custom_saved_tensors = (
            gateW, gateW_quant, gateS,
            upW, upW_quant, upS,
            downW, downW_quant, downS,
            _backward_function,
        )
        ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
                              X, e, g)
        ctx.inplace = inplace
        return i
    pass


    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dY : torch.Tensor):
        gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
            _backward_function = ctx.custom_saved_tensors
        gateA, gateB, upA, upB, downA, downB, \
            X, e, g = ctx.saved_tensors

        # Flatten (batch, seq) so everything below is plain 2-D matmuls.
        batch, seq_len, hd = X.shape
        dY = dY.view(-1, dY.shape[-1])
        X = X .view(-1, X .shape[-1])
        e = e .view(-1, e .shape[-1])
        g = g .view(-1, g .shape[-1])
        dtype = X.dtype

        gateA, gateB, upA, upB, downA, downB = \
            gateA.to(dtype), gateB.to(dtype), upA.to(dtype), upB.to(dtype), downA.to(dtype), downB.to(dtype)

        gateA, gateB, upA, upB, downA, downB = \
            gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()

        # D @ W.T, then the activation backward rewrites (DW, e, g) in place
        # into (h, df, de) — NOTE the in-place aliasing below is intentional.
        DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
        DW, e, g = _backward_function(DW, e, g)
        h, df, de = DW, e, g

        d_downA = torch.empty_like(downA)
        d_downB = torch.empty_like(downB)
        d_gateA = torch.empty_like(gateA)
        d_gateB = torch.empty_like(gateB)
        d_upA = torch.empty_like(upA)
        d_upB = torch.empty_like(upB)

        # Down projection LoRA weights
        # d_downA = h.t() @ (dY @ downB.t())
        # d_downB = (downA.t() @ h.t()) @ dY
        # d_downA *= downS
        # d_downB *= downS
        # beta=0 means "overwrite" — the empty_like buffers hold garbage.
        d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
        d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)

        # Up projection LoRA weights
        # d_upA = X.t() @ (df @ upB.t())
        # d_upB = (upA.t() @ X.t()) @ df
        # d_upA *= upS
        # d_upB *= upS
        d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
        d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)

        # Gate projection LoRA weights
        # d_gateA = X.t() @ (de @ gateB.t())
        # d_gateB = (gateA.t() @ X.t()) @ de
        # d_gateA *= gateS
        # d_gateB *= gateS
        d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
        d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)

        # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
        # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
        # Dequantize, use, then free each base weight to cap peak memory.
        upW = fast_dequantize(upW.t(), upW_quant)
        # When inplace, dX reuses X's storage (X is not needed after here).
        dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
        del upW
        # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
        dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)

        gateW = fast_dequantize(gateW.t(), gateW_quant)
        # dX += de @ gateW.t()
        dX.addmm_(de, gateW.t())
        del gateW
        # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
        dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)

        # Return order mirrors forward's inputs:
        # gateW, gateW_quant, gateA, gateB, gateS,
        # upW, upW_quant, upA, upB, upS,
        # downW, downW_quant, downA, downB, downS,
        return dX.view(batch, seq_len, hd), \
            None, None, d_gateA.t(), d_gateB.t(), None, \
            None, None, d_upA.t(), d_upB.t(), None, \
            None, None, d_downA.t(), d_downB.t(), None, \
            None, None, None, # _backward and _forward and inplace
    pass
pass
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
| 177 |
+
def apply_lora_mlp_swiglu(self, X, inplace = True):
    """Run the fused SwiGLU LoRA MLP.

    Gathers the (possibly quantized) base weights plus LoRA A/B factors and
    scalings for the gate/up/down projections, then dispatches to the
    LoRA_MLP autograd function with the SwiGLU forward/backward kernels.
    """
    gate_params = get_lora_parameters(self.gate_proj)
    up_params   = get_lora_parameters(self.up_proj)
    down_params = get_lora_parameters(self.down_proj)
    return LoRA_MLP.apply(
        X,
        *gate_params,
        *up_params,
        *down_params,
        swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
        inplace,
    )
pass
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
| 192 |
+
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
    """Run the fused exact-GEGLU LoRA MLP.

    Same structure as apply_lora_mlp_swiglu, but uses the exact (erf-based)
    GEGLU forward/backward kernels inside the LoRA_MLP autograd function.
    """
    gate_params = get_lora_parameters(self.gate_proj)
    up_params   = get_lora_parameters(self.up_proj)
    down_params = get_lora_parameters(self.down_proj)
    return LoRA_MLP.apply(
        X,
        *gate_params,
        *up_params,
        *down_params,
        geglu_exact_forward_kernel, geglu_exact_backward_kernel,
        inplace,
    )
pass
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
|
| 207 |
+
def apply_lora_mlp_geglu_approx(self, X, inplace = True):
    """Run the fused approximate (tanh) GEGLU LoRA MLP.

    Mirrors apply_lora_mlp_swiglu / apply_lora_mlp_geglu_exact: gathers the
    (possibly quantized) base weights plus LoRA factors for the gate/up/down
    projections and dispatches to LoRA_MLP with the tanh-approximation GEGLU
    kernels.

    Consistency fix: the sibling helpers accept and forward `inplace`, but this
    one did not, silently relying on LoRA_MLP.apply's default. `inplace` is now
    an explicit parameter; the default True preserves the previous behavior.
    """
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(X,
                         gateW, gateW_quant, gateA, gateB, gateS,
                         upW, upW_quant, upA, upB, upS,
                         downW, downW_quant, downA, downB, downS,
                         geglu_approx_forward_kernel, geglu_approx_backward_kernel,
                         inplace,)
    return out
pass
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class LoRA_QKV(torch.autograd.Function):
    """
    Fused Q/K/V LoRA projection: computes all three projections in one
    autograd Function so the backward pass can share work on X and optionally
    accumulate dX in place.

    ### LoRA weights
    Wq = Wq + Aq @ Bq
    Wk = Wk + Ak @ Bk
    Wv = Wv + Av @ Bv
    Q = X @ Wq = X @ Wq + X @ Aq @ Bq
    K = X @ Wk = X @ Wk + X @ Ak @ Bk
    V = X @ Wv = X @ Wv + X @ Av @ Bv

    ### Backpropagation chain rule
    See our blogpost for more details.

    dC/dWq = X.T @ D(Wq)
    dC/dWk = X.T @ D(Wk)
    dC/dWv = X.T @ D(Wv)
    We then sum them all find dC/dX

    ### Q projection LoRA weights
    dC/dAq = X.T @ D(Wq) @ B.T
    dC/dBq = A.T @ X.T @ D(Wq)

    ### K projection LoRA weights
    dC/dAk = X.T @ D(Wk) @ B.T
    dC/dBk = A.T @ X.T @ D(Wk)

    ### V projection LoRA weights
    dC/dAv = X.T @ D(Wv) @ B.T
    dC/dBv = A.T @ X.T @ D(Wv)
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, X : torch.Tensor,
                QW, QW_quant, QA, QB, QS,
                KW, KW_quant, KA, KB, KS,
                VW, VW_quant, VA, VB, VS,
                inplace = True):
        # X: (batch, seq_len, hidden). *W are (possibly quantized) frozen base
        # weights, *A/*B the LoRA factors, *S the LoRA scalings per projection.
        dtype = X.dtype

        Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
        K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
        V = matmul_lora(X, VW, VW_quant, VA, VB, VS)

        # Frozen / non-differentiable state goes on ctx directly; only the
        # tensors needing grad bookkeeping use save_for_backward.
        ctx.custom_saved_tensors = (
            QW, QW_quant, QS,
            KW, KW_quant, KS,
            VW, VW_quant, VS,
        )
        ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
        ctx.inplace = inplace
        return Q, K, V
    pass

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dQ, dK, dV):
        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
            ctx.custom_saved_tensors
        X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors

        # Flatten (batch, seq_len, hd) -> (batch*seq_len, hd) for 2D matmuls.
        batch, seq_len, hd = X.shape
        dQ = dQ.view(-1, dQ.shape[-1])
        dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
        dV = dV.view(-1, dV.shape[-1])
        X = X .view(-1, X .shape[-1])
        dtype = X.dtype

        # LoRA factors may be stored in a different dtype than the activations.
        QA, QB, KA, KB, VA, VB = \
            QA.to(dtype), QB.to(dtype), KA.to(dtype), KB.to(dtype), VA.to(dtype), VB.to(dtype)

        QA, QB, KA, KB, VA, VB = \
            QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()

        ### Weight projection LoRA weights
        # See our blogpost for more details.
        # empty_like + addmm_ with beta = 0 overwrites the buffer, so the
        # uninitialized contents are never read.
        d_QA = torch.empty_like(QA)
        d_QB = torch.empty_like(QB)
        d_KA = torch.empty_like(KA)
        d_KB = torch.empty_like(KB)
        d_VA = torch.empty_like(VA)
        d_VB = torch.empty_like(VB)

        # Q Projection
        # d_QA = X.t() @ (dQ @ QB.t())
        # d_QB = (QA.t() @ X.t()) @ dQ
        # d_QA *= QS
        # d_QB *= QS
        d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
        d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)

        # K Projection
        # d_KA = X.t() @ (dK @ KB.t())
        # d_KB = (KA.t() @ X.t()) @ dK
        # d_KA *= KS
        # d_KB *= KS
        d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
        d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)

        # V Projection
        # d_VA = X.t() @ (dV @ VB.t())
        # d_VB = (VA.t() @ X.t()) @ dV
        # d_VA *= VS
        # d_VB *= VS
        d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
        d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)

        # Combine derivatives to find dX
        # dQ
        QW = fast_dequantize(QW.t(), QW_quant)
        # When inplace, dX reuses X's (flattened) storage — X must not be
        # needed after this point, and indeed it is not read again below.
        dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
        del QW
        # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
        dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)

        # dK
        KW = fast_dequantize(KW.t(), KW_quant)
        # dX += dK @ KW.t()
        dX.addmm_(dK, KW.t())
        del KW
        # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
        dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)

        # dV
        VW = fast_dequantize(VW.t(), VW_quant)
        # dX += dV @ VW.t()
        dX.addmm_(dV, VW.t())
        del VW
        # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
        dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)

        # Gradient slots mirror forward's signature:
        # QW, QW_quant, QA, QB, QS,
        # KW, KW_quant, KA, KB, KS,
        # VW, VW_quant, VA, VB, VS,
        return dX.view(batch, seq_len, hd), \
            None, None, d_QA.t(), d_QB.t(), None, \
            None, None, d_KA.t(), d_KB.t(), None, \
            None, None, d_VA.t(), d_VB.t(), None, \
            None,
    pass
pass
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def apply_lora_qkv(self, X, inplace = True):
    """Compute the Q, K and V projections with the fused LoRA_QKV function."""
    q_params = get_lora_parameters(self.q_proj)
    k_params = get_lora_parameters(self.k_proj)
    v_params = get_lora_parameters(self.v_proj)
    return LoRA_QKV.apply(X, *q_params, *k_params, *v_params, inplace)
pass
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class LoRA_W(torch.autograd.Function):
    """
    Single LoRA-augmented linear projection.

    NOTE(review): the previous docstring here was a copy-paste of LoRA_QKV's
    (it described Q/K/V); rewritten to describe what this Function does.

    ### LoRA weights
    W_eff = W + A @ B (scaled by S)
    Y = X @ W_eff = X @ W + S * (X @ A @ B)

    ### Backpropagation chain rule
    dC/dA = X.T @ dY @ B.T   (scaled by S)
    dC/dB = A.T @ X.T @ dY   (scaled by S)
    dC/dX = dY @ W.T + S * (dY @ B.T @ A.T)
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, X : torch.Tensor,
                W, W_quant, A, B, S):
        # W is the (possibly quantized) frozen base weight; A, B the LoRA
        # factors; S the LoRA scaling.
        dtype = X.dtype
        XW = matmul_lora(X, W, W_quant, A, B, S)
        # Frozen state on ctx; differentiable tensors via save_for_backward.
        ctx.custom_saved_tensors = (W, W_quant, S,)
        ctx.save_for_backward(A, B, X)
        return XW
    pass

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dY : torch.Tensor):
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors

        # Flatten (batch, seq_len, hd) -> (batch*seq_len, hd) for 2D matmuls.
        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
        X = X .reshape(-1, X .shape[-1]) # Must be reshape
        dtype = X.dtype

        # LoRA factors may be stored in a different dtype than the activations.
        A, B = A.to(dtype), B.to(dtype)

        A, B = A.t(), B.t()

        # empty_like + addmm_ with beta = 0 overwrites the buffers, so the
        # uninitialized contents are never read.
        d_A = torch.empty_like(A)
        d_B = torch.empty_like(B)

        ### Weight projection LoRA weights
        # Weight projection
        # d_A = X.t() @ (dY @ B.t())
        # d_B = (A.t() @ X.t()) @ dY
        # d_A *= S
        # d_B *= S
        d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
        d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)

        # Get derivative for dX
        W = fast_dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
        dX.addmm_(dY @ B.t(), A.t(), alpha = S)

        # Gradient slots mirror forward's signature: W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), \
            None, None, d_A.t(), d_B.t(), None
    pass
pass
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def apply_lora_o(self, X):
    """Output (o_proj) projection with LoRA via the LoRA_W autograd function."""
    weight, weight_quant, lora_A, lora_B, scaling = get_lora_parameters(self.o_proj)
    return LoRA_W.apply(X, weight, weight_quant, lora_A, lora_B, scaling)
pass
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# Alias used for the isinstance fast-path check below.
IDENTITY_DROPOUT = torch.nn.Identity
@torch._disable_dynamo
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # PEFT-style LoRA Linear forward with an unsloth fast path (LoRA_W.apply).
    # Deliberately disabled: everything below the raise is dead code kept for
    # when the reshaping issue is fixed.
    raise NotImplementedError(
        "Unsloth: Currently not supported yet - reshaping done incorrectly"
    )
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        # Fastpath: a single active adapter with identity dropout and no DoRA
        # can go straight through the fused LoRA_W autograd function.
        if len(self.active_adapters) == 1:
            active_adapter = self.active_adapters[0]
            if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)

            dropout = self.lora_dropout[active_adapter]
            if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
                lora_A = self.lora_A[active_adapter].weight
                lora_B = self.lora_B[active_adapter].weight
                scaling = self.scaling[active_adapter]
                W = self.base_layer.weight
                return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
            pass
        pass

        # Slow path: standard PEFT accumulation over all active adapters.
        result = self.base_layer(x, *args, **kwargs)
        # As per Tim Dettmers, for 4bit, we need to defensively clone here.
        # The reason is that in some cases, an error can occur that backprop
        # does not work on a manipulated view. This issue may be solved with
        # newer PyTorch versions but this would need extensive testing to be
        # sure.
        result = result.clone()

        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            # Outside autocast, inputs must be converted to the adapter dtype
            # by hand (and the result converted back afterwards).
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                result = result + lora_B(lora_A(dropout(x))) * scaling
            else:
                # DoRA: magnitude vector handles the update; only apply dropout
                # up front when it is a real (training-time) dropout.
                if isinstance(dropout, torch.nn.Identity) or not self.training:
                    base_result = result
                else:
                    x = dropout(x)
                    base_result = None

                result = result + self.lora_magnitude_vector[active_adapter](
                    x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    scaling=scaling,
                    base_layer=self.get_base_layer(),
                    base_result=base_result,
                )
            if requires_conversion:
                result = result.to(expected_dtype)

        return result
pass
|
build/torch-universal/unsloth_kernels/flex_attention.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from functools import lru_cache
|
| 17 |
+
from transformers.models.llama.modeling_llama import logger
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# Inductor options shared by every torch.compile call in this module.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : True,
    "shape_padding" : True,
    "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
    "triton.cudagraphs" : False,
}

# Flex Attention supported from torch 2.5 onwards only
try:
    from torch.nn.attention.flex_attention import (
        flex_attention as _flex_attention,
        create_block_mask as _create_block_mask,
    )
    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
    # NOTE(review): this is False even when the import and compile SUCCEED,
    # so the flex-attention branch below is never taken. Confirm this is an
    # intentional kill-switch and not a typo for True.
    HAS_FLEX_ATTENTION = False
except:
    # NOTE(review): bare `except:` also swallows KeyboardInterrupt/SystemExit;
    # `except Exception:` (or ImportError) would be safer here.
    HAS_FLEX_ATTENTION = False
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# HAS_FLEX_ATTENTION is always False above, so only the first branch is
# currently active; the flex-attention branch is kept for when it is re-enabled.
if not HAS_FLEX_ATTENTION:

    # Logit softcapping
    @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        # Compiled eager attention with Gemma-style logit softcapping.
        # Q: (bsz, n_heads, q_len, head_dim); K, V: (bsz, n_kv_heads, q_len, head_dim).
        n_heads = self.config.num_attention_heads
        head_dim = self.head_dim
        n_kv_heads = self.config.num_key_value_heads
        n_groups = self.num_key_value_groups

        # Grouped query attention: repeat each KV head across its query group.
        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
        K = K.reshape(bsz, n_heads, q_len, head_dim)
        V = V.reshape(bsz, n_heads, q_len, head_dim)

        # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
        # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
        # We default to using the config file itself
        # s = self.config.hidden_size // self.config.num_attention_heads
        s = self.config.query_pre_attn_scalar
        t = self.config.attn_logit_softcapping

        Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
        A = torch.matmul(Q, K.transpose(2, 3))
        A = t * torch.tanh(A / t) # Logit softcapping
        # causal_mask is additive (already -inf where masked).
        A += causal_mask[:q_len, :q_len]
        # Much slower in torch compile!
        # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
        A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
        A = torch.matmul(A, V)
        A = A.transpose(1, 2).contiguous()
        A = A.reshape(bsz, q_len, n_heads*head_dim)
        return A
    pass

    # No flex attention: no block-mask constructors are exported.
    create_flex_attention_causal_mask = None
    create_flex_attention_sliding_window_mask = None
else:
    # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
    # for more examples
    # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
    import functools, math

    def generate_tanh_softcap(t):
        # score_mod for flex_attention: softcap each attention logit at +/- t.
        def tanh_softcap(x, b, h, q_idx, kv_idx):
            return t * torch.tanh(x / t)
        return tanh_softcap
    pass
    def causal_masker(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    pass

    @functools.lru_cache
    def sliding_window_masker(size = 4096):
        # Causal mask further restricted to a trailing window of `size` tokens.
        def sliding_window(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            window_mask = q_idx - kv_idx <= size
            return causal_mask & window_mask
        return sliding_window
    pass

    @functools.lru_cache
    def create_block_mask(mask, n = 128):
        return _create_block_mask(
            mask, 1, 1, n, n,
            BLOCK_SIZE = 128,
            _compile = True,
        )
    pass

    def create_flex_attention_causal_mask(max_seq_length = 8192):
        causal_mask = create_block_mask(causal_masker, max_seq_length)
        return causal_mask
    pass

    def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
        sliding_masker = sliding_window_masker(sliding_window)
        causal_mask = create_block_mask(sliding_masker, max_seq_length)
        return causal_mask
    pass

    @functools.lru_cache
    def flex_attention(s, t):
        # Cached per (scale source, softcap) pair; returns a ready-to-call
        # partial of the compiled flex_attention.
        scale = 1.0 / math.sqrt(s)
        score_mod = generate_tanh_softcap(t)
        return functools.partial(
            _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
        )
    pass

    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        # Flex-attention version: GQA expansion handled by enable_gqa = True,
        # so K/V keep their n_kv_heads layout here.
        n_heads = self.config.num_attention_heads
        head_dim = self.head_dim
        s = self.config.query_pre_attn_scalar
        t = self.config.attn_logit_softcapping
        fx = flex_attention(s, t)
        A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
        A = A.transpose(1, 2).contiguous()
        A = A.reshape(bsz, q_len, n_heads*head_dim)
        return A
    pass
pass
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
torch_matmul = torch.matmul
|
| 147 |
+
torch_tanh = torch.tanh
|
| 148 |
+
torch_nn_functional_softmax = torch.nn.functional.softmax
|
| 149 |
+
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
| 150 |
+
n_heads = self.config.num_attention_heads
|
| 151 |
+
head_dim = self.head_dim
|
| 152 |
+
n_kv_heads = self.config.num_key_value_heads
|
| 153 |
+
n_groups = self.num_key_value_groups
|
| 154 |
+
|
| 155 |
+
# Grouped query attention
|
| 156 |
+
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
| 157 |
+
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
| 158 |
+
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
| 159 |
+
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
| 160 |
+
|
| 161 |
+
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
| 162 |
+
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
|
| 163 |
+
# We default to using the config file itself
|
| 164 |
+
# s = self.config.hidden_size // self.config.num_attention_heads
|
| 165 |
+
s = self.config.query_pre_attn_scalar
|
| 166 |
+
t = self.config.attn_logit_softcapping
|
| 167 |
+
|
| 168 |
+
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
| 169 |
+
A = torch_matmul(Q, K.transpose(2, 3))
|
| 170 |
+
|
| 171 |
+
# Logit softcapping
|
| 172 |
+
A /= t; torch_tanh(A, out = A); A *= t;
|
| 173 |
+
A += causal_mask[:q_len, :q_len]
|
| 174 |
+
# Much slower in torch compile!
|
| 175 |
+
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
| 176 |
+
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
| 177 |
+
A = torch_matmul(A, V)
|
| 178 |
+
A = A.transpose(1, 2).contiguous()
|
| 179 |
+
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
| 180 |
+
return A
|
| 181 |
+
pass
|
build/torch-universal/unsloth_kernels/geglu.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
import torch
|
| 18 |
+
from .utils import (
|
| 19 |
+
calculate_settings,
|
| 20 |
+
triton_tanh,
|
| 21 |
+
torch_cuda_device,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@triton.jit
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    # Exact GEGLU forward over flattened buffers: h = GELU(e) * g, where
    # e is the gate projection, g the up projection, h the output buffer.
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
    # h = f * up
    # GELU is computed in fp32 and cast back to g's dtype to match HF exactly.
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
    f_row = f_row.to(g_row.dtype) # Exact copy from HF
    h_row = f_row * g_row

    # Store h
    tl.store(h + offsets, h_row, mask = mask)
pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def geglu_exact_forward_kernel(gate, up):
    """Launch the exact-GEGLU forward Triton kernel.

    Computes GELU(gate) * up elementwise and returns a new tensor with gate's
    shape, dtype and device. Inputs are not modified.
    """
    batch, seq_len, hd = gate.shape
    total = gate.numel()
    dev = gate.device
    result = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = dev)
    launch_grid = lambda meta: (triton.cdiv(total, meta['BLOCK_SIZE']),)
    with torch_cuda_device(dev):
        _exact_forward_kernel[launch_grid](gate, up, result, total, BLOCK_SIZE = 1024,)
    return result
pass
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """
    Exact GEGLU backward, computed fully in place over the three buffers.

    f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
    h = f * up

    df/de (with help of Wolfram :)
    df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)

    Reuse via
    f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e

    On exit the buffers are overwritten:
        DW <- h  = f * g   (recomputed forward output)
        e  <- df = DW * f  (gradient w.r.t. the up projection)
        g  <- de           (gradient w.r.t. the gate projection)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # Break e_row away for re-use
    # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
    f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
    f_row = f_partial_row * e_row

    f_row = f_row.to(DW_row.dtype)
    # h = f * g
    h_row = f_row * g_row
    # df = DW * f
    df_row = DW_row * f_row
    # dg = DW * g
    dg_row = DW_row * g_row

    # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
    t = 0.3989422804014327 # 1/sqrt(2*pi)
    df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)

    de_row = dg_row.to(tl.float32) * df_de
    de_row = de_row.to(DW_row.dtype)

    # Store derivatives in buffers
    tl.store(DW + offsets, h_row, mask = mask) # h = f * g
    tl.store(e + offsets, df_row, mask = mask) # df = DW * f
    tl.store(g + offsets, de_row, mask = mask) # de
pass
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def geglu_exact_backward_kernel(DW, e, g):
    """Launch the exact-GEGLU backward Triton kernel.

    The three buffers are overwritten in place (DW <- h, e <- df, g <- de)
    and returned. Expects 2D (batch*seq_len, hd) buffers.
    """
    batch_seq_len, hd = e.shape
    total = e.numel()
    launch_grid = lambda meta: (triton.cdiv(total, meta['BLOCK_SIZE']),)
    with torch_cuda_device(e.device):
        _exact_backward_kernel[launch_grid](DW, e, g, total, BLOCK_SIZE = 1024,)
    return DW, e, g
pass
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@triton.jit
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    # Tanh-approximation GEGLU forward over flattened buffers:
    # h = gelu_tanh(e) * g, where e is the gate projection and g the up
    # projection; math in fp32, cast back to g's dtype to match HF.
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
    # h = f * up
    s = 0.7978845608028654 # math.sqrt(2 / math.pi)

    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    f_row = 0.5 * e_row * (
        triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
        + 1.0
    )
    f_row = f_row.to(g_row.dtype) # Exact copy from HF
    h_row = f_row * g_row

    # Store h
    tl.store(h + offsets, h_row, mask = mask)
pass
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def geglu_approx_forward_kernel(gate, up):
    """Launch the tanh-approximation GEGLU forward Triton kernel.

    Computes gelu_tanh(gate) * up elementwise and returns a new tensor with
    gate's shape, dtype and device. Inputs are not modified.
    """
    batch, seq_len, hd = gate.shape
    total = gate.numel()
    dev = gate.device
    result = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = dev)
    launch_grid = lambda meta: (triton.cdiv(total, meta['BLOCK_SIZE']),)
    with torch_cuda_device(dev):
        _approx_forward_kernel[launch_grid](gate, up, result, total, BLOCK_SIZE = 1024,)
    return result
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@triton.jit
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """
    In-place approximate-GEGLU backward. On exit the buffers are overwritten:
    DW <- h = f * g, e <- df = DW * f, g <- de = (DW * g) * df/de.

    f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
    h = f * up

    df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
    df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
        1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
        ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )

    Notice sech^2(x) = 1 - tanh^2(x)
    So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )

    See https://www.desmos.com/calculator/nqprfoni6x
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # See https://www.desmos.com/calculator/nqprfoni6x
    s = 0.7978845608028654 # math.sqrt(2 / math.pi)
    a = s * e_row # a = sqrt(2 / pi) * x
    b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
    T = 1.0 + triton_tanh(a + b)
    T2 = 0.5 * T
    # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
    # Uses -T*(T-2) = 1 - (T-1)^2 = sech^2 of the tanh argument.
    Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
    df_de = T2 + Q2 # 1/2 * (T + Q)

    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
    f_row = T2 * e_row
    f_row = f_row.to(DW_row.dtype)
    # h = f * g
    h_row = f_row * g_row
    # df = DW * f
    df_row = DW_row * f_row
    # dg = DW * g
    dg_row = DW_row * g_row

    # Chain rule through the gate: de = dg * df/de, computed in float32.
    de_row = dg_row.to(tl.float32) * df_de
    de_row = de_row.to(DW_row.dtype)

    # Store derivatives in buffers
    tl.store(DW + offsets, h_row, mask = mask) # h = f * g
    tl.store(e + offsets, df_row, mask = mask) # df = DW * f
    tl.store(g + offsets, de_row, mask = mask) # de
pass
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def geglu_approx_backward_kernel(DW, e, g):
    """Approximate-GEGLU backward: launches ``_approx_backward_kernel``, which
    overwrites DW with h = f*g, e with df = DW*f and g with de (the gradient
    through the gate), then returns the three updated buffers.
    """
    n_rows, hidden_dim = e.shape  # enforces 2D (batch*seq_len, hd) input
    total = e.numel()
    launch_grid = lambda meta: (triton.cdiv(total, meta['BLOCK_SIZE']),)
    with torch_cuda_device(e.device):
        _approx_backward_kernel[launch_grid](DW, e, g, total, BLOCK_SIZE = 1024,)
    return DW, e, g
pass
|
build/torch-universal/unsloth_kernels/layernorm.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
import torch
|
| 19 |
+
from .utils import calculate_settings, torch_cuda_device
|
| 20 |
+
from unsloth_zoo.patching_utils import (
|
| 21 |
+
patch_layernorm,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@triton.jit
def layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W,
    b,
    r,
    mu,
    n_cols : tl.constexpr,
    eps : tl.constexpr,
    BLOCK_SIZE : tl.constexpr
):
    """One-row LayerNorm forward: Y = (X - mean) * rsqrt(var + eps) * W + b.

    One program handles one row of X. Per-row inverse std is stored in
    r[row] and the per-row mean in mu[row] for reuse in the backward pass.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance base pointers to this program's row.
    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx
    mu += row_idx

    # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
    # are in float32!
    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)

    mean_X = tl.sum(X_row, axis = 0) / n_cols
    # (X[0] - mean) == -mean so we need to mask it out
    XX = tl.where(mask, X_row - mean_X, 0)
    row_var = tl.sum(XX * XX, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store (r, inv_var)
    tl.store (mu, mean_X)
    output = (XX * inv_var) * W_row + b_row
    tl.store(Y + col_offsets, output, mask = mask)
pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@triton.jit
def layernorm_backward(
    dY, dY_row_stride,
    X, X_row_stride,
    W,
    b,
    r,
    mu,
    n_cols : tl.constexpr,
    eps : tl.constexpr,
    BLOCK_SIZE : tl.constexpr
):
    """One-row LayerNorm input-gradient, written back into dY in place.

    Uses the inv-std (r) and mean (mu) saved by layernorm_forward.
    NOTE(review): b_row and eps are loaded/declared but never used below —
    presumably kept for signature symmetry with the forward kernel.
    This kernel produces dX only; no dW/db reduction is done here.
    """
    # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance base pointers to this program's row.
    dY += row_idx * dY_row_stride
    X += row_idx * X_row_stride
    r += row_idx
    mu += row_idx

    # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
    # are in float32!
    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)

    inv_var = tl.load(r) .to(tl.float32)
    mean = tl.load(mu).to(tl.float32)
    normed = (X_row - mean) * inv_var
    dY_W = dY_row * W_row
    # dX = (dY*W - mean(dY*W) - normed * mean(dY*W * normed)) * inv_std
    dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
    dX_row = dX_row * inv_var
    tl.store(dY + col_offsets, dX_row, mask = mask)
pass
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Fast_Layernorm(torch.autograd.Function):
    """Autograd wrapper around the Triton layernorm_forward/backward kernels.

    Flattens the input to (rows, dim), launches one program per row, and
    caches per-row inv-std (r) and mean (mu) for the backward pass.
    """
    @staticmethod
    def forward(ctx, X, W, b, eps):
        shape = X.shape
        dim = shape[-1]
        X = X.view(-1, dim)  # collapse leading dims into rows
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        device = X.device
        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
        # Per-row statistics kept in float32 for the backward pass.
        r = torch.empty(n_rows, dtype = torch.float32, device = device)
        mu = torch.empty(n_rows, dtype = torch.float32, device = device)

        with torch_cuda_device(device):
            layernorm_forward[(n_rows,)](
                Y, Y.stride(0),
                X, X.stride(0),
                W,
                b,
                r,
                mu,
                n_cols, eps,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps = num_warps,
            )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, b, r, mu)
        return Y.view(*shape)
    pass

    @staticmethod
    def backward(ctx, dY):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, b, r, mu = ctx.saved_tensors
        n_rows, n_cols = dY.shape

        # The kernel overwrites dY in place with dX.
        with torch_cuda_device(dY.device):
            layernorm_backward[(n_rows,)](
                dY, dY.stride(0),
                X, X .stride(0),
                W,
                b,
                r,
                mu,
                n_cols, ctx.eps,
                BLOCK_SIZE = ctx.BLOCK_SIZE,
                num_warps = ctx.num_warps,
            )
        dX = dY.view(*shape)
        # Only dX is produced; W/b/eps get no gradient from this Function.
        return dX, None, None, None, None
    pass
pass
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def fast_layernorm(layernorm, X):
    """Run the Triton LayerNorm on X using ``layernorm``'s affine parameters.

    The module must be elementwise-affine (weight and bias present);
    epsilon is read from ``variance_epsilon`` when present, else ``eps``.
    """
    assert(layernorm.elementwise_affine is True)
    if hasattr(layernorm, "variance_epsilon"):
        eps = layernorm.variance_epsilon
    else:
        eps = layernorm.eps
    return Fast_Layernorm.apply(X, layernorm.weight, layernorm.bias, eps)
pass
|
build/torch-universal/unsloth_kernels/rms_layernorm.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
import torch
|
| 18 |
+
from .utils import calculate_settings, torch_cuda_device
|
| 19 |
+
|
| 20 |
+
@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride : tl.constexpr,
    n_cols : tl.constexpr,
    eps : tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
    Fast RMS Layernorm kernel: Y = (X * rsqrt(mean(X^2) + eps)) * W.
    One program per row; the per-row inverse RMS is stored in r[row]
    for reuse in the backward pass. The normalized value is cast to W's
    dtype before the multiply (exact HF behaviour).
    Inspiration from a Triton tutorial:
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    normed = normed.to(W_row.dtype) # Exact copy from HF
    output = normed * W_row
    tl.store(Y + col_offsets, output, mask = mask)
pass
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _rms_layernorm_backward(
    dY, dY_row_stride,
    dX, dX_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride : tl.constexpr,
    # dW, dW_row_stride,
    n_cols : tl.constexpr,
    eps : tl.constexpr,
    GEMMA : tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
    Fast RMS Layernorm kernel for the backward pass.
    Produces dX only (no dW reduction). When GEMMA is set, the effective
    weight is (W + 1) and dX is written to the separate dX buffer;
    otherwise dY is overwritten in place.
    Inspiration from a Triton tutorial:
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY += row_idx * dY_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    # Gemma writes into a fresh buffer; the default path reuses dY's storage.
    if GEMMA: dX += row_idx * dY_row_stride
    else: dX = dY

    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

    # Get saved row variance
    inv_var = tl.load(r).to(tl.float32)
    normed = X_row * inv_var

    if GEMMA: dY_W = dY_row * (W_row + 1.0)
    else: dY_W = dY_row * W_row

    # dX = inv_rms * (dY*W - normed * mean(dY*W * normed))
    rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
    output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
    tl.store(dX + col_offsets, output, mask = mask)
pass
# Decorate after definition so the heuristic can specialize on GEMMA.
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
_rms_layernorm_backward = triton.heuristics(
    {
        "GEMMA": lambda args: bool(args["GEMMA"]),
    }
)(_rms_layernorm_backward)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@triton.jit
def _gemma_rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride : tl.constexpr,
    n_cols : tl.constexpr,
    eps : tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """Gemma RMSNorm forward: Y = (X * rsqrt(mean(X^2)+eps)) * (W + 1),
    with every step in float32 (unlike the Llama variant, which casts
    the normalized value to W's dtype first)."""
    # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
    # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
    # exactly. Essentially all in float32!
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    output = normed * (W_row + 1.0)

    tl.store(Y + col_offsets, output, mask = mask)
pass
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Fast_RMS_Layernorm(torch.autograd.Function):
    """Autograd wrapper for the Triton RMSNorm kernels.

    ``gemma=True`` selects the Gemma variant (full-float32 math, (W+1)
    scaling, and an out-of-place backward buffer).
    """
    @staticmethod
    def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
        shape = X.shape
        dim : int = shape[-1]
        X = X.view(-1, dim)  # collapse leading dims into rows
        n_rows : int
        n_cols : int
        n_rows, n_cols = X.shape
        BLOCK_SIZE : int
        num_warps : int
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        device = X.device

        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
        # Per-row inverse RMS, saved in float32 for the backward pass.
        r = torch.empty(n_rows, dtype = torch.float32, device = device)

        fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
        with torch_cuda_device(device):
            fx[(n_rows,)](
                Y, Y.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, eps,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps = num_warps,
            )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.GEMMA = gemma
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)
    pass

    @staticmethod
    def backward(ctx, dY : torch.Tensor):
        shape = dY.shape
        dim : int = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows : int
        n_cols : int
        n_rows, n_cols = dY.shape
        # dW = X
        # Gemma needs a fresh output buffer; otherwise dY is reused in place.
        dX = torch.empty_like(dY) if ctx.GEMMA else dY

        with torch_cuda_device(dY.device):
            _rms_layernorm_backward[(n_rows,)](
                dY, dY.stride(0),
                dX, dX.stride(0),
                X, X .stride(0),
                W, W .stride(0),
                r, r .stride(0),
                # dW, dW.stride(0),
                n_cols, ctx.eps,
                GEMMA = ctx.GEMMA,
                BLOCK_SIZE = ctx.BLOCK_SIZE,
                num_warps = ctx.num_warps,
            )
        dX = dX.view(*shape)
        # W/eps/gemma receive no gradient from this Function.
        return dX, None, None, None
    pass
pass
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
@torch.compiler.disable
def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
    """Apply Fast_RMS_Layernorm with the module's weight and epsilon.

    Works with modules exposing either ``variance_epsilon`` (HF style)
    or ``eps``. Pass ``gemma=True`` for Gemma-style (W + 1) scaling.
    """
    weight : torch.Tensor = layernorm.weight
    if hasattr(layernorm, "variance_epsilon"):
        eps : float = layernorm.variance_epsilon
    else:
        eps : float = layernorm.eps
    return Fast_RMS_Layernorm.apply(X, weight, eps, gemma)
pass
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
| 221 |
+
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
    """Drop-in LlamaRMSNorm whose forward runs the Triton fast_rms_layernorm."""
    def forward(self, X):
        return fast_rms_layernorm(self, X, gemma = False)
    pass
pass
|
| 226 |
+
|
| 227 |
+
try:
    from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
    class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
        """Drop-in MllamaTextRMSNorm whose forward runs the Triton fast_rms_layernorm."""
        def forward(self, X):
            return fast_rms_layernorm(self, X, gemma = False)
        pass
    pass
# Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still propagate.
except Exception:
    # mllama only exists in newer transformers versions — best effort.
    pass
pass
|
| 237 |
+
|
| 238 |
+
def patch_rms_layernorm():
    """Monkey-patch HF transformers RMSNorm classes with the Triton versions.

    Replaces LlamaRMSNorm always, and MllamaTextRMSNorm when the mllama
    module is importable (newer transformers only).
    """
    import transformers.models.llama.modeling_llama
    transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
    try:
        import transformers.models.mllama.modeling_mllama
        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still propagate.
    except Exception:
        pass
    return
pass
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def unpatch_rms_layernorm():
    """Restore the original HF transformers RMSNorm classes (inverse of
    patch_rms_layernorm)."""
    import transformers.models.llama.modeling_llama
    transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
    try:
        import transformers.models.mllama.modeling_mllama
        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still propagate.
    except Exception:
        pass
    return
pass
|
| 260 |
+
|
| 261 |
+
|
build/torch-universal/unsloth_kernels/rope_embedding.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
import torch
|
| 18 |
+
from .utils import calculate_settings, torch_cuda_device
|
| 19 |
+
# Number of attention heads processed per Triton program in the RoPE kernel.
# NOTE: the kernel body hard-codes the same value locally — keep them in sync.
ROPE_GROUP_SIZE : int = 4
|
| 20 |
+
|
| 21 |
+
def _rope_embedding(
    Q, Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen,
    head_dim : tl.constexpr,
    n_heads : tl.constexpr,
    BACKWARD_PASS : tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
    Calculates the RoPE Embedding quickly, in place on Q.
    RoPE is Q * cos + rotate_half(Q) * sin
    The backward pass is the same rotation with sin negated.
    See our blog post for more info
    """
    # Local copy must match the module-level ROPE_GROUP_SIZE used for the grid.
    ROPE_GROUP_SIZE = 4
    row_position = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    # cos/sin are indexed by sequence position (row % seqlen folds the batch).
    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    if BACKWARD_PASS:
        # See our blog post for more info.
        sin1 = -sin1
    pass

    # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
    head_start = group_head_position * ROPE_GROUP_SIZE
    head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)

    # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
    for k in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim

        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        # 2D rotation of the (Q1, Q2) half-pairs, written back in place.
        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
    pass
pass
# Decorate after definition so the heuristic can specialize on BACKWARD_PASS.
_rope_embedding = triton.jit(_rope_embedding)
_rope_embedding = triton.heuristics(
    {
        "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
    }
)(_rope_embedding)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Fast_RoPE_Embedding(torch.autograd.Function):
    """Autograd wrapper for the Triton RoPE kernel.

    Expects Q as (batch, seq_len, n_heads, head_dim); the rotation is done
    in place on a flattened (batch*seq_len, n_heads*head_dim) view. The
    backward pass reuses the same kernel with sin negated.
    """
    @staticmethod
    def forward(ctx, Q, cos, sin):
        cos, sin = cos.squeeze(), sin.squeeze()
        batch : int
        seq_len : int
        n_heads : int
        head_dim : int
        batch, seq_len, n_heads, head_dim = Q.shape
        Q = Q.view(batch*seq_len, n_heads*head_dim)
        n_rows : int
        n_cols : int
        n_rows, n_cols = Q.shape
        # cos/sin tables must cover the whole sequence.
        assert(seq_len <= cos.shape[0])

        # [TODO] Changing blocksize to head_dim//2 seems to have
        # some concurrency / un-deterministic issues.
        BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)

        # group_size = 4 # 4 or 8, too large group_size can hurt performance.
        div : int
        mod : int
        # ceil(n_heads / ROPE_GROUP_SIZE) programs along the second grid axis.
        div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
        n_groups : int = div + (mod != 0)

        with torch_cuda_device(Q.device):
            _rope_embedding[(n_rows, n_groups, )](
                Q, Q.stride(0),
                cos, cos.stride(0),
                sin, sin.stride(0),
                seqlen,
                head_dim, n_heads,
                BACKWARD_PASS = False,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps = num_warps,
            ) if False else _rope_embedding[(n_rows, n_groups, )](
                Q, Q.stride(0),
                cos, cos.stride(0),
                sin, sin.stride(0),
                seq_len,
                head_dim, n_heads,
                BACKWARD_PASS = False,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps = num_warps,
            )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.n_groups = n_groups
        ctx.cos = cos
        ctx.sin = sin
        return Q.view(batch, seq_len, n_heads, head_dim)
    pass

    @staticmethod
    def backward(ctx, dY):
        batch : int
        seq_len : int
        n_heads : int
        head_dim : int
        batch, seq_len, n_heads, head_dim = dY.shape
        dY = dY.reshape(batch*seq_len, n_heads*head_dim)
        # Must be reshape not view
        n_rows : int
        n_cols : int
        n_rows, n_cols = dY.shape

        cos = ctx.cos
        sin = ctx.sin

        with torch_cuda_device(dY.device):
            _rope_embedding[(n_rows, ctx.n_groups, )](
                dY, dY .stride(0),
                cos, cos.stride(0),
                sin, sin.stride(0),
                seq_len, head_dim, n_heads,
                BACKWARD_PASS = True,
                BLOCK_SIZE = ctx.BLOCK_SIZE,
                num_warps = ctx.num_warps,
            )
        dY = dY.view(batch, seq_len, n_heads, head_dim)
        return dY, None, None,
    pass
pass
|
| 152 |
+
|
| 153 |
+
# [TODO] Unsure why RoPE Embedding is not torch.compiling properly
@torch.compiler.disable
def fast_rope_embedding(Q, K, cos, sin):
    """Apply the Triton RoPE kernel to Q and K and return the rotated pair.

    The kernel expects (batch, seq, n_heads, head_dim), so each tensor is
    transposed in and back out (callers presumably pass
    (batch, n_heads, seq, head_dim) — confirm at the call sites).
    """
    rotated = []
    for tensor in (Q, K):
        out = Fast_RoPE_Embedding.apply(tensor.transpose(1, 2), cos, sin)
        rotated.append(out.transpose(1, 2))
    return rotated[0], rotated[1]
pass
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Slow_RoPE_Embedding(torch.autograd.Function):
    """Reference (pure-PyTorch) rotary embedding, applied in place on Q.

    forward:  Q <- Q * cos + rotate_half(Q) * sin
    backward: the transposed rotation, dY <- dY * cos + rotate_half^T(dY) * sin
    """
    @staticmethod
    def forward(ctx, Q, cos, sin, position_ids):
        if position_ids is not None:
            # cos/sin arrive with two leading unit dims; drop them, gather
            # the requested positions, then restore a broadcastable head dim.
            cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
            sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]

        half_dim = Q.shape[-1] // 2
        lo, hi = Q[..., :half_dim], Q[..., half_dim:]
        rotated = torch.cat((-hi, lo), dim = -1)  # rotate_half(Q), materialized before the in-place ops
        # In place: Q = Q * cos + rotate_half(Q) * sin
        Q.mul_(cos).addcmul_(rotated, sin)
        ctx.save_for_backward(cos, sin)
        return Q
    pass

    @staticmethod
    def backward(ctx, dY):
        cos, sin = ctx.saved_tensors
        half_dim = dY.shape[-1] // 2
        lo, hi = dY[..., :half_dim], dY[..., half_dim:]
        rotated_T = torch.cat((hi, -lo), dim = -1)  # rotate_half^T(dY)
        # In place: dY = dY * cos + rotate_half^T(dY) * sin
        dY.mul_(cos).addcmul_(rotated_T, sin)
        return dY, None, None, None
    pass
pass
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
    """Apply the slow in-place RoPE (Slow_RoPE_Embedding) to Q and K."""
    apply_rope = Slow_RoPE_Embedding.apply
    return apply_rope(Q, cos, sin, position_ids), apply_rope(K, cos, sin, position_ids)
pass
|
build/torch-universal/unsloth_kernels/swiglu.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
import torch
|
| 18 |
+
from .utils import calculate_settings, torch_cuda_device
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    """Fused SwiGLU forward over flat buffers: h = (e * sigmoid(e)) * g.

    e, g, h are pointers to equally-sized tensors; each program instance
    processes one BLOCK_SIZE-wide slice of the flattened data.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask guards the ragged tail block when n_elements % BLOCK_SIZE != 0.
    mask = offsets < n_elements

    # Gate activation computed in fp32 for accuracy; g stays in storage dtype.
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # f = e * sigmoid(e)
    f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
    # Downcast before the product so results match HF's reference exactly.
    f_row = f_row.to(g_row.dtype) # Exact copy from HF
    # h = f * g
    h_row = f_row * g_row

    # Store h
    tl.store(h + offsets, h_row, mask = mask)
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def swiglu_fg_kernel(e, g):
    """Compute h = (e * sigmoid(e)) * g with the fused Triton forward kernel.

    ``e`` and ``g`` are (batch, seq_len, hidden) tensors on the same device;
    the result has the same shape/dtype/device as ``e``.
    """
    batch, seq_len, hd = e.shape
    total_elements = e.numel()
    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
    # One program per BLOCK_SIZE elements of the flattened tensor.
    def grid(meta):
        return (triton.cdiv(total_elements, meta["BLOCK_SIZE"]),)
    with torch_cuda_device(e.device):
        _fg_kernel[grid](e, g, h, total_elements, BLOCK_SIZE = 1024,)
    return h
pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """Fused SwiGLU backward; overwrites its three input buffers in place:
    DW <- h (recomputed forward), e <- df, g <- de.

    Reference (PyTorch) semantics:
    e = e.float()
    se = 1.0 / (1.0 + torch.exp(-e))
    f = (se * e).to(dtype)
    h = f * g
    df = DW * f
    dg = DW * g
    de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the ragged tail block.
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
    # Only e is promoted to fp32; sigmoid accuracy dominates the error.
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # e = e.float()
    # se = 1.0 / (1.0 + torch.exp(-e))
    se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
    # f = (se * e).to(dtype)
    f_row = se_row * e_row
    f_row = f_row.to(DW_row.dtype)
    # h = f * g
    h_row = f_row * g_row
    # df = DW * f
    df_row = DW_row * f_row
    # dg = DW * g
    dg_row = DW_row * g_row
    # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
    de_row = de_row.to(DW_row.dtype)

    # Store derivatives in buffers.  All loads happen above, so reusing the
    # inputs as outputs is safe within one program instance.
    tl.store(DW + offsets, h_row, mask = mask) # h = f * g
    tl.store(e + offsets, df_row, mask = mask) # df = DW * f
    tl.store(g + offsets, de_row, mask = mask) # de
pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
    """Launch the fused SwiGLU backward kernel.

    The three 2D buffers are overwritten in place (DW <- h, e <- df, g <- de)
    and returned as-is.
    """
    _, hd = e.shape
    total_elements = e.numel()
    # One program per BLOCK_SIZE elements of the flattened tensors.
    def grid(meta):
        return (triton.cdiv(total_elements, meta["BLOCK_SIZE"]),)
    with torch_cuda_device(e.device):
        _DWf_DW_dfg_kernel[grid](DW, e, g, total_elements, BLOCK_SIZE = 1024,)
    return DW, e, g
pass
|
build/torch-universal/unsloth_kernels/utils.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
MAX_FUSED_SIZE : int = 65536
|
| 17 |
+
next_power_of_2 = triton.next_power_of_2
|
| 18 |
+
import functools
|
| 19 |
+
|
| 20 |
+
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
# Select one spelling of the autocast decorators so the autograd Functions in
# this package can use `torch_amp_custom_fwd/bwd` regardless of torch version.
import torch
torch_Tensor = torch.Tensor  # cached alias used in annotations below
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
    # Old spelling lives under torch.cuda.amp.
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    # New spelling is torch.amp.* parameterized by device_type.
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass


# tl.math.tanh now is libdevice.tanh
# Same idea for Triton: expose `triton_tanh` / `triton_cast` uniformly.
from packaging.version import Version
import triton
import triton.language as tl
if Version(triton.__version__) >= Version("3.0.0"):
    from triton.language.extra import libdevice
    triton_tanh = libdevice.tanh
    triton_cast = tl.cast
else:
    triton_tanh = tl.math.tanh
    # No casting in old Triton versions
    # Shim: emulate tl.cast via .to() inside a jitted helper.
    @triton.jit
    def triton_cast(x, dtype):
        return x.to(dtype)
    pass
pass
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def calculate_settings(n : int) -> "tuple[int, int]":
    """Choose a Triton launch configuration for rows of ``n`` elements.

    Returns ``(BLOCK_SIZE, num_warps)`` where BLOCK_SIZE is ``n`` rounded up
    to the next power of two and num_warps scales with the block size so
    larger rows get more intra-block parallelism.

    Raises:
        RuntimeError: if the padded block exceeds MAX_FUSED_SIZE (65536).
    """
    # Fix: the original annotation `-> (int, int,)` built a runtime tuple of
    # types, which is not a valid PEP 484 annotation; a string annotation is
    # version-safe and type-checker friendly.
    BLOCK_SIZE : int = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps : int = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps
pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
import bitsandbytes as bnb
import ctypes

# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
# bnb > 0.43.3 requires an explicit CUDA stream argument on its C entrypoints.
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr

# Only pay for device-context switching when there is more than one GPU.
if torch.cuda.device_count() > 1:
    torch_cuda_device = torch.cuda.device
else:
    from contextlib import nullcontext
    def torch_cuda_device(device): return nullcontext()
pass
_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
c_void_p = ctypes.c_void_p
def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
    # Raw CUDA stream handle for the tensor's device, as a ctypes pointer.
    return c_void_p(_cuda_getCurrentRawStream(tensor.device.index))
pass

# Get array of CUDA streams and other buffers
# NOTE(review): the streams are captured once at import time — assumes the
# current stream per device does not change afterwards; verify for callers
# that use non-default streams.
global CUDA_STREAMS
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS

# NOTE(review): `max(...)` below raises ValueError on a zero-GPU machine —
# this module assumes at least one visible CUDA device.
_CUDA_STREAMS = {
    (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
    for i in range(torch.cuda.device_count())
}
CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
CUDA_STREAMS = tuple(CUDA_STREAMS)
del _CUDA_STREAMS

# Bitsandbytes operations
# Cache C entrypoints and hot torch callables as module-level names so the
# per-call attribute lookups disappear from the hot path.
ctypes_c_int = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
torch_mm = torch.mm
torch_mv = torch.mv
torch_matmul = torch.matmul
torch_addmm = torch.addmm
torch_empty = torch.empty
|
| 112 |
+
|
| 113 |
+
def QUANT_STATE(W):
    """Return W's bitsandbytes quant_state, or None when W is unquantized."""
    return getattr(W, "quant_state", None)
|
| 114 |
+
|
| 115 |
+
def get_lora_parameters(proj):
    """Extract ``(W, quant_state, lora_A, lora_B, scaling)`` from a linear layer.

    ``proj`` may be a plain linear layer or a PEFT-wrapped one.  When adapters
    are disabled or already merged (e.g. for DPO), the three LoRA entries are
    None and only the (possibly quantized) base weight matters.
    """
    # For DPO or disabled adapters
    # PEFT wraps the real linear layer in `base_layer`; plain layers are used as-is.
    base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
    if getattr(proj, "disable_adapters", True) or proj.merged:
        return W, getattr(W, "quant_state", None), None, None, None
    pass

    # Newer PEFT exposes `active_adapters` (a list); older versions expose
    # `active_adapter`, which may be a list/tuple OR a bare string.
    adapter = getattr(proj, "active_adapters", None)
    if adapter is None: adapter = getattr(proj, "active_adapter", "default")
    # BUGFIX: indexing a bare string took its first character ("default"[0]
    # == "d"), producing a wrong adapter key; only index real sequences.
    if not isinstance(adapter, str): adapter = adapter[0]

    return (
        W,
        getattr(W, "quant_state", None),
        proj.lora_A [adapter].weight,
        proj.lora_B [adapter].weight,
        proj.scaling[adapter],
    )
pass
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_lora_parameters_bias(proj):
    """Like ``get_lora_parameters`` but also returns the base layer's bias.

    Returns ``(W, quant_state, lora_A, lora_B, scaling, bias)``; the three
    LoRA entries are None when adapters are disabled or already merged.
    """
    # For DPO or disabled adapters
    # PEFT wraps the real linear layer in `base_layer`; plain layers are used as-is.
    base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
    if getattr(proj, "disable_adapters", True) or proj.merged:
        return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias
    pass

    # Newer PEFT exposes `active_adapters` (a list); older versions expose
    # `active_adapter`, which may be a list/tuple OR a bare string.
    adapter = getattr(proj, "active_adapters", None)
    if adapter is None: adapter = getattr(proj, "active_adapter", "default")
    # BUGFIX: indexing a bare string took its first character ("default"[0]
    # == "d"), producing a wrong adapter key; only index real sequences.
    if not isinstance(adapter, str): adapter = adapter[0]

    return (
        W,
        getattr(W, "quant_state", None),
        proj.lora_A [adapter].weight,
        proj.lora_B [adapter].weight,
        proj.scaling[adapter],
        base_layer.bias,
    )
pass
|
| 162 |
+
|
| 163 |
+
if HAS_CUDA_STREAM:
    @torch.inference_mode
    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
        """Dequantize a bitsandbytes NF4 weight W into an fp16/bf16 matrix.

        Passes raw pointers straight to the bnb C library, skipping its Python
        wrapper overhead.  With ``use_global_buffer=True`` the result is a view
        into a per-device global buffer reused across calls (do not keep it
        alive across calls).  Returns W unchanged when there is no quant_state.
        """
        if quant_state is None: return W
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            offset = quant_state.offset
            state2 = quant_state.state2  # nested stats: absmax is itself quantized
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        global CUDA_STREAMS
        device = W.device
        device_index = device.index
        CUDA_STREAM = CUDA_STREAMS[device_index]

        n_elements_absmax = absmax.numel()

        # Create weight matrix
        if use_global_buffer:

            # Use same buffers for faster inference
            size = shape[0]*shape[1]
            global WEIGHT_BUFFERS
            global ABSMAX_BUFFERS
            WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
            ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
            if WEIGHT_BUFFER is None:
                # Lazily allocate both buffers on first use for this device.
                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

            # Grow (never shrink) the cached buffers when a larger W appears.
            if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
            if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)

            out = WEIGHT_BUFFER[:size].view(shape)
            out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
        else:
            if out is None:
                out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
            else:
                assert(out.shape == shape)
                assert(out.dtype == dtype)
            out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
        pass

        # NF4 dequantization of statistics
        ptr_out_absmax = get_ptr(out_absmax)
        with torch_cuda_device(device):
            # First recover the fp32 absmax values from the nested quantization.
            cdequantize_blockwise_fp32(
                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
                ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM
            )
            out_absmax += offset

            # Dequantize W
            fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
                cdequantize_blockwise_bf16_nf4
            fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
               ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
        pass
        # Careful returning transposed data
        is_transposed = (True if W.shape[0] == 1 else False)
        return out.t() if is_transposed else out
    pass
else:
    @torch.inference_mode
    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
        """Dequantize a bitsandbytes NF4 weight W (pre-0.43.4 bnb: no stream arg).

        Same semantics as the streamed variant above, except no global-buffer
        reuse; ``use_global_buffer`` is accepted for signature parity but ignored.
        """
        if quant_state is None: return W
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass

        n_elements_absmax = absmax.numel()
        device = W.device

        # Create weight matrix
        if out is None:
            out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
        else:
            assert(out.shape == shape)
            assert(out.dtype == dtype)
        out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

        # Do dequantization
        ptr_out_absmax = get_ptr(out_absmax)
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
            ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
        )
        out_absmax += offset

        fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
            cdequantize_blockwise_bf16_nf4
        fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
           ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)

        # Careful returning transposed data
        is_transposed = (True if W.shape[0] == 1 else False)
        return out.t() if is_transposed else out
    pass
pass
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if HAS_CUDA_STREAM:
    def fast_gemv(X, W, quant_state, out = None):
        """Fused 4-bit GEMV: X @ dequantize(W) for single-token decoding.

        X is (1, 1, hd); dispatches to bnb's naive 4-bit inference kernel
        without materializing the full dequantized W.  Falls back to a plain
        matmul when W is unquantized.
        """
        if quant_state is None: return torch_matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            stats = quant_state.code
            offset = quant_state.offset
            state2 = quant_state.state2  # nested stats: absmax is itself quantized
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        global CUDA_STREAMS
        device = W.device
        device_index = device.index
        CUDA_STREAM = CUDA_STREAMS[device_index]

        # assert(dtype == X.dtype)
        bout = shape[0]

        if out is None:
            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        # GEMV dimensions / leading strides for the bnb C kernel.
        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd+1)//2  # packed 4-bit: two weights per byte
        m = ctypes_c_int32(m)
        n = ctypes_c_int32(n)
        k = ctypes_c_int32(k)
        lda = ctypes_c_int32(lda)
        ldb = ctypes_c_int32(ldb)
        ldc = ctypes_c_int32(ldc)

        # Recover fp32 absmax statistics before the fused gemv.
        df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
        with torch_cuda_device(device):
            cdequantize_blockwise_fp32(
                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
                ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
            )
            df += offset
            absmax = df

            fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
                cgemm_4bit_inference_naive_bf16

            blocksize = ctypes_c_int32(blocksize)
            fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
               lda, ldb, ldc, blocksize, CUDA_STREAM,)
        pass

        return out
    pass
else:
    def fast_gemv(X, W, quant_state, out = None):
        """Fused 4-bit GEMV (pre-0.43.4 bnb: C entrypoints take no stream arg).

        Same semantics as the streamed variant above.
        """
        if quant_state is None: return torch.matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax = quant_state.absmax
            shape = quant_state.shape
            dtype = quant_state.dtype
            blocksize = quant_state.blocksize
            stats = quant_state.code
            offset = quant_state.offset
            state2 = quant_state.state2
            absmax2 = state2.absmax
            code2 = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        # assert(dtype == X.dtype)
        bout = shape[0]
        device = W.device

        if out is None:
            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        # GEMV dimensions / leading strides for the bnb C kernel.
        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd+1)//2  # packed 4-bit: two weights per byte
        m = ctypes_c_int32(m)
        n = ctypes_c_int32(n)
        k = ctypes_c_int32(k)
        lda = ctypes_c_int32(lda)
        ldb = ctypes_c_int32(ldb)
        ldc = ctypes_c_int32(ldc)

        # Recover fp32 absmax statistics before the fused gemv.
        df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
            ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
        )
        df += offset
        absmax = df

        fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
            cgemm_4bit_inference_naive_bf16

        blocksize = ctypes_c_int32(blocksize)
        fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
           lda, ldb, ldc, blocksize,)

        return out
    pass
pass
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def fast_linear_forward(proj, X, temp_lora = None, out = None):
    """Inference-optimized linear forward for a (possibly LoRA'd, possibly
    4-bit) projection layer.

    X is (bsz, q_len, in_dim).  For q_len != 1 this defers to matmul_lora;
    for single-token decoding it picks the fastest base-weight path
    (plain matmul / fused 4-bit gemv / buffered dequant + matmul), then adds
    the LoRA contribution and bias.  ``temp_lora``/``out`` are optional
    preallocated scratch buffers.
    """

    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
    bsz, q_len, in_dim = X.shape
    if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)

    if W_quant is None:
        # Unquantized weight: ordinary matmul against W^T.
        out = torch_matmul(X, W.t(), out = out)
    elif bsz == 1 and q_len == 1:
        # Single token, single sequence: fused 4-bit gemv avoids dequantizing W.
        out = fast_gemv(X, W, W_quant, out = out)
    else:
        # Batched single-token decode: dequantize into the shared global buffer.
        W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
        out = torch_matmul(X, W, out = out)
    pass

    # Add in LoRA weights
    if lora_A is not None:
        out_dim = out.shape[2]
        dtype = X.dtype

        # Cache dtype-converted adapter weights on the modules (first call only).
        if not hasattr(lora_A, "_fast_lora"):
            lora_A._fast_lora = lora_A.to(dtype)
            lora_B._fast_lora = lora_B.to(dtype)
        pass

        if bsz == 1:
            # out += lora_S * B @ (A @ x) using fused mv/addmv on flat vectors.
            out = out.view(out_dim)
            temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
        else:
            # Batched equivalent with mm/addmm on (bsz, dim) matrices.
            out = out.view(bsz, out_dim)
            temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
        pass
        out = out.view(bsz, 1, out_dim)
    pass

    if bias is not None: out += bias

    return out
pass
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def matmul_lora(X, W, W_quant, A, B, s, out = None):
    """Compute X @ W^T plus the scaled LoRA term s * (X @ A^T) @ B^T.

    W may be 4-bit quantized (W_quant not None), in which case it is first
    dequantized into the shared global buffer.  X may be 2D or 3D; a 3D input
    is flattened for the matmuls and reshaped back on return.  A/B are the
    LoRA factors as stored (pre-transpose); ``out`` is an optional output
    buffer for the base matmul.
    """
    dtype = X.dtype
    W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)

    if X.dim() == 3:
        batch, seq_len, d = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False
    pass
    out = torch_matmul(X, W, out = out)
    # Drop the reference to the dequantized copy promptly (it may alias the
    # shared global buffer or be a large temporary).
    if W_quant is not None: del W

    if A is not None:
        # LoRA is enabled
        A, B = A.t(), B.t()
        XA = torch_matmul(X, A.to(dtype))
        # In-place accumulate the scaled low-rank update into the base output.
        out.addmm_(XA, B.to(dtype), alpha = s)
        # out += (X @ A.to(dtype)) @ (s * B.to(dtype))
    pass

    return out.view(batch, seq_len, -1) if reshape else out
pass
|
flake.lock
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1733328505,
|
| 6 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-utils": {
|
| 19 |
+
"inputs": {
|
| 20 |
+
"systems": "systems"
|
| 21 |
+
},
|
| 22 |
+
"locked": {
|
| 23 |
+
"lastModified": 1731533236,
|
| 24 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 25 |
+
"owner": "numtide",
|
| 26 |
+
"repo": "flake-utils",
|
| 27 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 28 |
+
"type": "github"
|
| 29 |
+
},
|
| 30 |
+
"original": {
|
| 31 |
+
"owner": "numtide",
|
| 32 |
+
"repo": "flake-utils",
|
| 33 |
+
"type": "github"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"kernel-builder": {
|
| 37 |
+
"inputs": {
|
| 38 |
+
"flake-compat": "flake-compat",
|
| 39 |
+
"flake-utils": "flake-utils",
|
| 40 |
+
"nixpkgs": "nixpkgs",
|
| 41 |
+
"rocm-nix": "rocm-nix"
|
| 42 |
+
},
|
| 43 |
+
"locked": {
|
| 44 |
+
"lastModified": 1745579622,
|
| 45 |
+
"narHash": "sha256-g8BXijChxDCZNu17M4Jj0GPv/7faVnArbHBOMNMpHjM=",
|
| 46 |
+
"owner": "huggingface",
|
| 47 |
+
"repo": "kernel-builder",
|
| 48 |
+
"rev": "e2f6f338737c6f1c570f9b59e43182633c0879c1",
|
| 49 |
+
"type": "github"
|
| 50 |
+
},
|
| 51 |
+
"original": {
|
| 52 |
+
"owner": "huggingface",
|
| 53 |
+
"repo": "kernel-builder",
|
| 54 |
+
"type": "github"
|
| 55 |
+
}
|
| 56 |
+
},
|
| 57 |
+
"nixpkgs": {
|
| 58 |
+
"locked": {
|
| 59 |
+
"lastModified": 1743559129,
|
| 60 |
+
"narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
|
| 61 |
+
"owner": "nixos",
|
| 62 |
+
"repo": "nixpkgs",
|
| 63 |
+
"rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
|
| 64 |
+
"type": "github"
|
| 65 |
+
},
|
| 66 |
+
"original": {
|
| 67 |
+
"owner": "nixos",
|
| 68 |
+
"ref": "nixos-unstable-small",
|
| 69 |
+
"repo": "nixpkgs",
|
| 70 |
+
"type": "github"
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"rocm-nix": {
|
| 74 |
+
"inputs": {
|
| 75 |
+
"nixpkgs": [
|
| 76 |
+
"kernel-builder",
|
| 77 |
+
"nixpkgs"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
"locked": {
|
| 81 |
+
"lastModified": 1745310663,
|
| 82 |
+
"narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
|
| 83 |
+
"owner": "huggingface",
|
| 84 |
+
"repo": "rocm-nix",
|
| 85 |
+
"rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
|
| 86 |
+
"type": "github"
|
| 87 |
+
},
|
| 88 |
+
"original": {
|
| 89 |
+
"owner": "huggingface",
|
| 90 |
+
"repo": "rocm-nix",
|
| 91 |
+
"type": "github"
|
| 92 |
+
}
|
| 93 |
+
},
|
| 94 |
+
"root": {
|
| 95 |
+
"inputs": {
|
| 96 |
+
"kernel-builder": "kernel-builder"
|
| 97 |
+
}
|
| 98 |
+
},
|
| 99 |
+
"systems": {
|
| 100 |
+
"locked": {
|
| 101 |
+
"lastModified": 1681028828,
|
| 102 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 103 |
+
"owner": "nix-systems",
|
| 104 |
+
"repo": "default",
|
| 105 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 106 |
+
"type": "github"
|
| 107 |
+
},
|
| 108 |
+
"original": {
|
| 109 |
+
"owner": "nix-systems",
|
| 110 |
+
"repo": "default",
|
| 111 |
+
"type": "github"
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
},
|
| 115 |
+
"root": "root",
|
| 116 |
+
"version": 7
|
| 117 |
+
}
|
flake.nix
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
{
|
| 2 |
-
description = "Flake for
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
-
kernel-builder.url = "
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|
|
|
|
| 1 |
{
|
| 2 |
+
description = "Flake for Unsloth Kernels";
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|