| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,187 +0,0 @@ |
| -#include <stdio.h> |
| -#include <assert.h> |
| - |
| -#define MIN_VALUE (-1e38) |
| - |
| -template <typename F> |
| -__global__ void kernel_forward( |
| - const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, |
| - const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset = _b * T * C + _c; |
| - |
| - F u = _u[_c]; |
| - F w = _w[_c]; |
| - const F *__restrict__ const k = _k + _offset; |
| - const F *__restrict__ const v = _v + _offset; |
| - F *__restrict__ const y = _y + _offset; |
| - |
| - // aa and bb are running sums divided by exp(pp) (to avoid overflow) |
| - F aa = 0, bb = 0, pp = MIN_VALUE; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const F kk = k[ii]; |
| - const F vv = v[ii]; |
| - |
| - F ww = u + kk; |
| - F p = max(pp, ww); |
| - F e1 = exp(pp - p); |
| - F e2 = exp(ww - p); |
| - y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| -} |
| - |
| -template <typename F> |
| -__global__ void kernel_forward_with_state( |
| - const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, |
| - const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset_s = _b * C * 3 + _c * 3; |
| - const int _offset = _b * T * C + _c; |
| - |
| - F u = _u[_c]; |
| - F w = _w[_c]; |
| - const F *__restrict__ const k = _k + _offset; |
| - const F *__restrict__ const v = _v + _offset; |
| - F *__restrict__ const y = _y + _offset; |
| - F *__restrict__ const s = _s + _offset_s; |
| - |
| - // aa and bb are running sums divided by exp(pp) (to avoid overflow) |
| - F aa = s[0], bb = s[1], pp = s[2]; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const F kk = k[ii]; |
| - const F vv = v[ii]; |
| - |
| - F ww = u + kk; |
| - F p = max(pp, ww); |
| - F e1 = exp(pp - p); |
| - F e2 = exp(ww - p); |
| - y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| - s[0] = aa; |
| - s[1] = bb; |
| - s[2] = pp; |
| -} |
| - |
| -template <typename F> |
| -__global__ void kernel_backward( |
| - const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, |
| - const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y, |
| - const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, |
| - F *__restrict__ const _gv |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset = _b * T * C + _c; |
| - |
| - F u = _u[_c]; |
| - F w = _w[_c]; |
| - const F *__restrict__ const k = _k + _offset; |
| - const F *__restrict__ const v = _v + _offset; |
| - const F *__restrict__ const y = _y + _offset; |
| - const F *__restrict__ const gy = _gy + _offset; |
| - F *__restrict__ const gk = _gk + _offset; |
| - F *__restrict__ const gv = _gv + _offset; |
| - |
| - F q[Tmax], r[Tmax]; |
| - |
| - F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const F kk = k[ii]; |
| - const F vv = v[ii]; |
| - const F yy = y[ii]; |
| - |
| - F ww = u + kk; |
| - F p = max(pp, ww); |
| - F e1 = exp(pp - p); |
| - F e2 = exp(ww - p); |
| - const F qq = gy[ii] / (e1 * bb + e2); |
| - gw += (ga - gb * yy) * e1 * qq; |
| - gu += (vv - yy) * e2 * qq; |
| - q[i] = qq; |
| - r[i] = ww - p; |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - ga = e1 * (aa + ga); |
| - gb = e1 * (bb + gb); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| - const int _offsetBC = _b * C + _c; |
| - _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() |
| - _gu[_offsetBC] = gu; |
| - |
| - aa = 0, bb = 0, pp = MIN_VALUE; |
| - for (int i = T - 1; i >= 0; i--) { |
| - const int ii = i * C; |
| - const F kk = k[ii]; |
| - const F vv = v[ii]; |
| - const F yy = y[ii]; |
| - const F qq = q[i]; |
| - const F rr = r[i]; |
| - |
| - F e1 = qq * exp(rr); |
| - F e2 = exp(kk + pp); |
| - gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); |
| - gv[ii] = e1 + e2 * aa; |
| - |
| - const F ww = w + pp; |
| - const F www = rr - u - kk; |
| - const F p = max(ww, www); |
| - e1 = exp(ww - p); |
| - e2 = qq * exp(www - p); |
| - aa = e1 * aa + e2; |
| - bb = e1 * bb - e2 * yy; |
| - pp = p; |
| - } |
| -} |
| - |
| -void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); |
| -} |
| - |
| -void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s); |
| -} |
| - |
| -void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); |
| -} |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,186 +0,0 @@ |
| -#include <stdio.h> |
| -#include <assert.h> |
| -#include "ATen/ATen.h" |
| -#define MIN_VALUE (-1e38) |
| -typedef at::BFloat16 bf16; |
| - |
| -__global__ void kernel_forward_bf16( |
| - const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, |
| - const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset = _b * T * C + _c; |
| - |
| - float u = float(_u[_c]); |
| - float w = _w[_c]; |
| - const bf16 *__restrict__ const k = _k + _offset; |
| - const bf16 *__restrict__ const v = _v + _offset; |
| - bf16 *__restrict__ const y = _y + _offset; |
| - |
| - // aa and bb are running sums divided by exp(pp) (to avoid overflow) |
| - float aa = 0, bb = 0, pp = MIN_VALUE; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const float kk = float(k[ii]); |
| - const float vv = float(v[ii]); |
| - |
| - float ww = u + kk; |
| - float p = max(pp, ww); |
| - float e1 = exp(pp - p); |
| - float e2 = exp(ww - p); |
| - y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| -} |
| - |
| -__global__ void kernel_forward_with_state_bf16( |
| - const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, |
| - const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y, |
| - float *__restrict__ const _s |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset_s = _b * C * 3 + _c * 3; |
| - const int _offset = _b * T * C + _c; |
| - |
| - float u = float(_u[_c]); |
| - float w = _w[_c]; |
| - const bf16 *__restrict__ const k = _k + _offset; |
| - const bf16 *__restrict__ const v = _v + _offset; |
| - bf16 *__restrict__ const y = _y + _offset; |
| - float *__restrict__ const s = _s + _offset_s; |
| - |
| - // aa and bb are running sums divided by exp(pp) (to avoid overflow) |
| - float aa = s[0], bb = s[1], pp = s[2]; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const float kk = float(k[ii]); |
| - const float vv = float(v[ii]); |
| - |
| - float ww = u + kk; |
| - float p = max(pp, ww); |
| - float e1 = exp(pp - p); |
| - float e2 = exp(ww - p); |
| - y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2); |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| - s[0] = aa; |
| - s[1] = bb; |
| - s[2] = pp; |
| -} |
| - |
| -__global__ void kernel_backward_bf16( |
| - const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, |
| - const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y, |
| - const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, |
| - bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv |
| -) { |
| - const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| - const int _b = idx / C; |
| - const int _c = idx % C; |
| - const int _offset = _b * T * C + _c; |
| - |
| - float u = float(_u[_c]); |
| - float w = _w[_c]; |
| - const bf16 *__restrict__ const k = _k + _offset; |
| - const bf16 *__restrict__ const v = _v + _offset; |
| - const bf16 *__restrict__ const y = _y + _offset; |
| - const bf16 *__restrict__ const gy = _gy + _offset; |
| - bf16 *__restrict__ const gk = _gk + _offset; |
| - bf16 *__restrict__ const gv = _gv + _offset; |
| - |
| - float q[Tmax], r[Tmax]; |
| - |
| - float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; |
| - for (int i = 0; i < T; i++) { |
| - const int ii = i * C; |
| - const float kk = float(k[ii]); |
| - const float vv = float(v[ii]); |
| - const float yy = float(y[ii]); |
| - |
| - float ww = u + kk; |
| - float p = max(pp, ww); |
| - float e1 = exp(pp - p); |
| - float e2 = exp(ww - p); |
| - const float qq = float(gy[ii]) / (e1 * bb + e2); |
| - gw += (ga - gb * yy) * e1 * qq; |
| - gu += (vv - yy) * e2 * qq; |
| - q[i] = qq; |
| - r[i] = ww - p; |
| - |
| - ww = w + pp; |
| - p = max(ww, kk); |
| - e1 = exp(ww - p); |
| - e2 = exp(kk - p); |
| - ga = e1 * (aa + ga); |
| - gb = e1 * (bb + gb); |
| - aa = e1 * aa + e2 * vv; |
| - bb = e1 * bb + e2; |
| - pp = p; |
| - } |
| - const int _offsetBC = _b * C + _c; |
| - _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() |
| - _gu[_offsetBC] = bf16(gu); |
| - |
| - aa = 0, bb = 0, pp = MIN_VALUE; |
| - for (int i = T - 1; i >= 0; i--) { |
| - const int ii = i * C; |
| - const float kk = float(k[ii]); |
| - const float vv = float(v[ii]); |
| - const float yy = float(y[ii]); |
| - const float qq = q[i]; |
| - const float rr = r[i]; |
| - |
| - float e1 = qq * exp(rr); |
| - float e2 = exp(kk + pp); |
| - gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); |
| - gv[ii] = bf16(e1 + e2 * aa); |
| - |
| - const float ww = w + pp; |
| - const float www = rr - u - kk; |
| - const float p = max(ww, www); |
| - e1 = exp(ww - p); |
| - e2 = qq * exp(www - p); |
| - aa = e1 * aa + e2; |
| - bb = e1 * bb - e2 * yy; |
| - pp = p; |
| - } |
| -} |
| - |
| -void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); |
| -} |
| - |
| -void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s); |
| -} |
| - |
| -void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { |
| - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| - assert(B * C % threadsPerBlock.x == 0); |
| - dim3 numBlocks(B * C / threadsPerBlock.x); |
| - kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); |
| -} |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,66 +0,0 @@ |
| -#include <torch/extension.h> |
| -#include "ATen/ATen.h" |
| -typedef at::BFloat16 bf16; |
| - |
| -void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); |
| -void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); |
| -void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s); |
| -void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s); |
| -void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); |
| -void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); |
| - |
| -void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); |
| -} |
| -void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>()); |
| -} |
| -void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>()); |
| -} |
| -void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>()); |
| -} |
| -void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); |
| -} |
| -void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { |
| - const int B = k.size(0); |
| - const int T = k.size(1); |
| - const int C = k.size(2); |
| - cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), |
| - gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>()); |
| -} |
| - |
| -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| - m.def("forward", &forward, "wkv forward"); |
| - m.def("forward_bf16", &forward_bf16, "wkv forward bf16"); |
| - m.def("forward_with_state", &forward_with_state, "wkv forward with state"); |
| - m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16"); |
| - m.def("backward", &backward, "wkv backward"); |
| - m.def("backward_bf16", &backward_bf16, "wkv backward bf16"); |
| -} |
| - |
| -TORCH_LIBRARY(wkv, m) { |
| - m.def("forward", forward); |
| - m.def("forward_bf16", forward_bf16); |
| - m.def("forward_with_state", forward_with_state); |
| - m.def("forward_with_state_bf16", forward_with_state_bf16); |
| - m.def("backward", backward); |
| - m.def("backward_bf16", backward_bf16); |
| -} |
| |
| |
| |
| |
| @@ -17,7 +17,6 @@ |
| |
| import math |
| from dataclasses import dataclass |
| -from pathlib import Path |
| from typing import Optional, Union |
| |
| import torch |
| @@ -30,6 +29,7 @@ |
| ModelOutput, |
| auto_docstring, |
| is_bitsandbytes_available, |
| + is_kernels_available, |
| is_ninja_available, |
| is_torch_cuda_available, |
| logging, |
| @@ -44,34 +44,13 @@ |
| |
| |
| def load_wkv_cuda_kernel(context_length): |
| - from torch.utils.cpp_extension import load as load_kernel |
| - |
| global rwkv_cuda_kernel |
| + if not is_kernels_available(): |
| + raise ImportError("kernels is not installed, please install it with `pip install kernels`") |
| + |
| + from kernels import get_kernel |
| |
| - kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv" |
| - cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]] |
| - |
| - # Only load the kernel if it's not been loaded yet or if we changed the context length |
| - if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length: |
| - return |
| - |
| - logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.") |
| - |
| - flags = [ |
| - "-res-usage", |
| - "--maxrregcount 60", |
| - "--use_fast_math", |
| - "-O3", |
| - "-Xptxas -O3", |
| - "--extra-device-vectorization", |
| - f"-DTmax={context_length}", |
| - ] |
| - rwkv_cuda_kernel = load_kernel( |
| - name=f"wkv_{context_length}", |
| - sources=cuda_kernel_files, |
| - verbose=(logging.get_verbosity() == logging.DEBUG), |
| - extra_cuda_cflags=flags, |
| - ) |
| + rwkv_cuda_kernel = get_kernel("kernels-community/rwkv") |
| rwkv_cuda_kernel.max_seq_length = context_length |
| |
| |
|
|