Build uploaded using `kernels`.
Browse files- .gitattributes +5 -0
- build/torch210-cxx11-cu128-x86_64-linux/__init__.py +5 -4
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so} +1 -1
- build/torch210-cxx11-cu128-x86_64-linux/metadata.json +9 -2
- build/torch210-cxx11-cu130-x86_64-linux/__init__.py +5 -4
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so} +1 -1
- build/torch210-cxx11-cu130-x86_64-linux/metadata.json +9 -2
- build/torch29-cxx11-cu128-x86_64-linux/__init__.py +5 -4
- build/torch29-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu128-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so} +1 -1
- build/torch29-cxx11-cu128-x86_64-linux/metadata.json +9 -2
- build/torch29-cxx11-cu129-x86_64-linux/__init__.py +208 -0
- build/torch29-cxx11-cu129-x86_64-linux/_ops.py +9 -0
- build/torch29-cxx11-cu129-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so +3 -0
- build/torch29-cxx11-cu129-x86_64-linux/metadata.json +12 -0
- build/torch29-cxx11-cu129-x86_64-linux/sgl_flash_attn3/__init__.py +26 -0
- build/torch29-cxx11-cu130-x86_64-linux/__init__.py +5 -4
- build/torch29-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu130-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so} +1 -1
- build/torch29-cxx11-cu130-x86_64-linux/metadata.json +9 -2
.gitattributes
CHANGED
|
@@ -37,3 +37,8 @@ build/torch210-cxx11-cu128-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so fi
|
|
| 37 |
build/torch210-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
build/torch29-cxx11-cu128-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 39 |
build/torch29-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
build/torch210-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
build/torch29-cxx11-cu128-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 39 |
build/torch29-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_b8d1001.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
build/torch210-cxx11-cu128-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
build/torch210-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
build/torch29-cxx11-cu128-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
build/torch29-cxx11-cu129-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
build/torch29-cxx11-cu130-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch210-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -86,7 +86,7 @@ def flash_attn_with_kvcache(
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
-
out, softmax_lse, *
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
@@ -123,7 +123,8 @@ def flash_attn_with_kvcache(
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def flash_attn_varlen_func(
|
|
@@ -166,7 +167,7 @@ def flash_attn_varlen_func(
|
|
| 166 |
)
|
| 167 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 168 |
|
| 169 |
-
out, softmax_lse, *
|
| 170 |
q,
|
| 171 |
k,
|
| 172 |
v,
|
|
@@ -204,4 +205,4 @@ def flash_attn_varlen_func(
|
|
| 204 |
sinks,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
return (out, softmax_lse
|
|
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
+
|
| 127 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
| 128 |
|
| 129 |
|
| 130 |
def flash_attn_varlen_func(
|
|
|
|
| 167 |
)
|
| 168 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 169 |
|
| 170 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 171 |
q,
|
| 172 |
k,
|
| 173 |
v,
|
|
|
|
| 205 |
sinks,
|
| 206 |
)
|
| 207 |
|
| 208 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
build/torch210-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _sgl_flash_attn3_cuda_1988a8e
|
| 3 |
+
ops = torch.ops._sgl_flash_attn3_cuda_1988a8e
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_sgl_flash_attn3_cuda_1988a8e::{op_name}"
|
build/torch210-cxx11-cu128-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 916761672
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6c2c92bea67e95646f7ede11a3a8a683c95467b988344155708fca7329358fc
|
| 3 |
size 916761672
|
build/torch210-cxx11-cu128-x86_64-linux/metadata.json
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
-
"python-depends": []
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda",
|
| 7 |
+
"archs": [
|
| 8 |
+
"8.0",
|
| 9 |
+
"9.0a"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|
build/torch210-cxx11-cu130-x86_64-linux/__init__.py
CHANGED
|
@@ -86,7 +86,7 @@ def flash_attn_with_kvcache(
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
-
out, softmax_lse, *
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
@@ -123,7 +123,8 @@ def flash_attn_with_kvcache(
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def flash_attn_varlen_func(
|
|
@@ -166,7 +167,7 @@ def flash_attn_varlen_func(
|
|
| 166 |
)
|
| 167 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 168 |
|
| 169 |
-
out, softmax_lse, *
|
| 170 |
q,
|
| 171 |
k,
|
| 172 |
v,
|
|
@@ -204,4 +205,4 @@ def flash_attn_varlen_func(
|
|
| 204 |
sinks,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
return (out, softmax_lse
|
|
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
+
|
| 127 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
| 128 |
|
| 129 |
|
| 130 |
def flash_attn_varlen_func(
|
|
|
|
| 167 |
)
|
| 168 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 169 |
|
| 170 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 171 |
q,
|
| 172 |
k,
|
| 173 |
v,
|
|
|
|
| 205 |
sinks,
|
| 206 |
)
|
| 207 |
|
| 208 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
build/torch210-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _sgl_flash_attn3_cuda_1988a8e
|
| 3 |
+
ops = torch.ops._sgl_flash_attn3_cuda_1988a8e
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_sgl_flash_attn3_cuda_1988a8e::{op_name}"
|
build/torch210-cxx11-cu130-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 917921992
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:763d2698ae560844d581a83f0ce4e8da584849f1beeb2e8ebf008fdc3db10c7a
|
| 3 |
size 917921992
|
build/torch210-cxx11-cu130-x86_64-linux/metadata.json
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
-
"python-depends": []
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda",
|
| 7 |
+
"archs": [
|
| 8 |
+
"8.0",
|
| 9 |
+
"9.0a"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|
build/torch29-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -86,7 +86,7 @@ def flash_attn_with_kvcache(
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
-
out, softmax_lse, *
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
@@ -123,7 +123,8 @@ def flash_attn_with_kvcache(
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def flash_attn_varlen_func(
|
|
@@ -166,7 +167,7 @@ def flash_attn_varlen_func(
|
|
| 166 |
)
|
| 167 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 168 |
|
| 169 |
-
out, softmax_lse, *
|
| 170 |
q,
|
| 171 |
k,
|
| 172 |
v,
|
|
@@ -204,4 +205,4 @@ def flash_attn_varlen_func(
|
|
| 204 |
sinks,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
return (out, softmax_lse
|
|
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
+
|
| 127 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
| 128 |
|
| 129 |
|
| 130 |
def flash_attn_varlen_func(
|
|
|
|
| 167 |
)
|
| 168 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 169 |
|
| 170 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 171 |
q,
|
| 172 |
k,
|
| 173 |
v,
|
|
|
|
| 205 |
sinks,
|
| 206 |
)
|
| 207 |
|
| 208 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
build/torch29-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _sgl_flash_attn3_cuda_1988a8e
|
| 3 |
+
ops = torch.ops._sgl_flash_attn3_cuda_1988a8e
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_sgl_flash_attn3_cuda_1988a8e::{op_name}"
|
build/torch29-cxx11-cu128-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 916755168
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:918b5b761bc5150da21a2af42ea8c8892fc51ef1200a4febaa1a05d706400ddb
|
| 3 |
size 916755168
|
build/torch29-cxx11-cu128-x86_64-linux/metadata.json
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
-
"python-depends": []
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda",
|
| 7 |
+
"archs": [
|
| 8 |
+
"8.0",
|
| 9 |
+
"9.0a"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|
build/torch29-cxx11-cu129-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ._ops import ops
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    """Return whether the FA3 kernels in this package can run on ``device``.

    Support requires both of the following:

    * the CUDA version reported by ``torch.version.cuda`` is at least 12.3;
    * the device's major compute capability is 8 or 9.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``.

    Returns:
        ``True`` when both conditions hold. ``False`` otherwise, including
        when PyTorch reports no CUDA version, no CUDA device is available,
        or the version string cannot be parsed.
    """
    cuda_version = torch.version.cuda
    # CPU-only builds report None; without a usable device there is nothing
    # to query, so answer False instead of raising.
    if cuda_version is None or not torch.cuda.is_available():
        return False

    # Compare numerically: as strings, "12.10" sorts before "12.3" and
    # "9.2" sorts after it.
    try:
        major_minor = tuple(int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False
    if major_minor < (12, 3):
        return False

    return torch.cuda.get_device_capability(device)[0] in (8, 9)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def maybe_contiguous(x):
    """Return ``x`` itself unless its last-dimension stride is not 1.

    ``None`` passes through untouched. Any other input whose ``stride(-1)``
    differs from 1 is replaced by ``x.contiguous()``.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    attention_chunk: Optional[int] = None,
    softcap=0.0,
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
):
    """
    FA3 flash_attn_with_kvcache: forward-only attention with paged KV cache,
    optional rotary embedding, sliding window, softcapping, GQA/MQA, and
    inplace KV cache update.

    See sgl-kernel documentation for full argument descriptions.

    Behavior visible in this wrapper:

    * ``k_cache`` and ``v_cache`` must have a last-dimension stride of 1
      (checked with ``assert``).
    * ``softmax_scale`` defaults to ``(q.shape[-1] + qv.shape[-1]) ** -0.5``,
      where the ``qv`` term is 0 when ``qv`` is None.
    * An ``int`` ``cache_seqlens`` is expanded to an int32 tensor with
      ``k_cache.shape[0]`` entries on ``k_cache.device``.
    * ``attention_chunk=None`` is passed to the kernel as 0.
    * ``window_size`` is indexed as a two-element sequence.

    Returns:
        ``out`` alone, or ``(out, softmax_lse)`` when ``return_softmax_lse``
        is true. Any further values returned by ``ops.fwd`` are discarded.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
            -0.5
        )
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one entry per k_cache.shape[0].
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)

    # Make the last dimension unit-stride wherever it is not already.
    q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
    # NOTE(review): with assertions enabled, v_cache.stride(-1) == 1 is
    # already guaranteed above, so this condition cannot be true and the
    # copy never happens — confirm whether that is intended.
    v_cache = (
        v_cache.contiguous()
        if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
        else v_cache
    )
    cu_seqlens_q, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
    ]
    page_table, cache_batch_idx, cache_leftpad = [
        maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    rotary_seqlens = maybe_contiguous(rotary_seqlens)
    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)

    # Arguments are positional; their order must match the compiled op.
    # Only the first two results are used.
    out, softmax_lse, *_ = ops.fwd(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,  # same slot flash_attn_varlen_func fills with seqused_k
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
        sinks,
    )

    return (out, softmax_lse) if return_softmax_lse else out
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
):
    """
    FA3 forward attention driven by ``cu_seqlens_q`` / ``cu_seqlens_k``.

    Behavior visible in this wrapper:

    * ``softmax_scale`` defaults to ``(q.shape[-1] + qv.shape[-1]) ** -0.5``,
      where the ``qv`` term is 0 when ``qv`` is None.
    * ``attention_chunk=None`` is passed to the kernel as 0.
    * ``window_size`` is indexed as a two-element sequence.
    * No new-KV, paging, cache-index, left-pad, rotary or scheduler-metadata
      inputs are forwarded; those kernel slots receive ``None``, and the
      rotary-interleaved flag is ``False``.

    Returns:
        ``out`` alone, or ``(out, softmax_lse)`` when ``return_softmax_lse``
        is true. Any further values returned by ``ops.fwd`` are discarded.

    Raises:
        NotImplementedError: if ``is_fa3_supported()`` is false.
        ValueError: if ``max_seqlen_q`` or ``max_seqlen_k`` is None.
    """
    # Called without a device argument, so the check is not tied to q's device.
    # NOTE(review): the message says "sm80 and above", but is_fa3_supported
    # only accepts a major compute capability of 8 or 9 — confirm the wording.
    if not is_fa3_supported():
        raise NotImplementedError(
            "sgl_flash_attn3 is only supported on sm80 and above with CUDA >= 12.3"
        )

    if max_seqlen_q is None or max_seqlen_k is None:
        raise ValueError("max_seqlen_q and max_seqlen_k are required for FA3")

    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
            -0.5
        )
    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)

    # Arguments are positional; their order must match the compiled op.
    # Only the first two results are used.
    out, softmax_lse, *_ = ops.fwd(
        q,
        k,
        v,
        None,  # k_new
        None,  # v_new
        qv,
        None,  # out
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # cu_seqlens_k_new
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        # NOTE(review): the `page_table` parameter is accepted by this
        # function but never forwarded; None is always passed here.
        # Confirm whether callers rely on it being ignored.
        None,  # page_table
        None,  # kv_batch_idx
        None,  # leftpad_k
        None,  # rotary cos
        None,  # rotary sin
        None,  # seqlens_rotary
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        attention_chunk,
        softcap,
        False,  # is_rotary_interleaved
        None,  # scheduler_metadata
        num_splits,
        pack_gqa,
        sm_margin,
        sinks,
    )

    return (out, softmax_lse) if return_softmax_lse else out
|
build/torch29-cxx11-cu129-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _sgl_flash_attn3_cuda_1988a8e
|
| 3 |
+
# Handle to the torch.ops namespace that shares its name with the compiled
# extension module imported above.
# NOTE(review): presumably importing that module is what registers the ops
# under this namespace — confirm against the build tooling.
ops = torch.ops._sgl_flash_attn3_cuda_1988a8e
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
    """
    Return ``op_name`` qualified with this build's op namespace, in the
    form ``<namespace>::<op_name>``.
    """
    qualified = "_sgl_flash_attn3_cuda_1988a8e::{}".format(op_name)
    return qualified
|
build/torch29-cxx11-cu129-x86_64-linux/_sgl_flash_attn3_cuda_1988a8e.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f40e01d849dfe7c460d1ff7c3a53396ce293024b241002cee8842ac8c6cb01be
|
| 3 |
+
size 921853920
|
build/torch29-cxx11-cu129-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"license": "BSD-3-Clause",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda",
|
| 7 |
+
"archs": [
|
| 8 |
+
"8.0",
|
| 9 |
+
"9.0a"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|
build/torch29-cxx11-cu129-x86_64-linux/sgl_flash_attn3/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
    """Import the Python file at ``file_path`` under a path-derived name.

    The module is registered in ``sys.modules`` before it is executed and
    is returned once execution succeeds.

    Args:
        file_path: Location of the Python source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec or loader can be built for the file.
        Exception: whatever the module's own top-level code raises. In that
            case the ``sys.modules`` entry is rolled back first.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly.
    import importlib.util

    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)

    # Register before executing so the entry exists while the module's
    # top-level code runs. Undo the registration if that code fails, so a
    # half-initialized module is never left behind.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Load the parent directory's __init__.py under a path-derived module name
# and copy every entry of its namespace into this module's globals.
# NOTE(review): vars() of the loaded module also contains its dunder
# attributes (e.g. __name__, __file__), so those are overwritten here too —
# presumably acceptable for this shim; confirm.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch29-cxx11-cu130-x86_64-linux/__init__.py
CHANGED
|
@@ -86,7 +86,7 @@ def flash_attn_with_kvcache(
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
-
out, softmax_lse, *
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
@@ -123,7 +123,8 @@ def flash_attn_with_kvcache(
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def flash_attn_varlen_func(
|
|
@@ -166,7 +167,7 @@ def flash_attn_varlen_func(
|
|
| 166 |
)
|
| 167 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 168 |
|
| 169 |
-
out, softmax_lse, *
|
| 170 |
q,
|
| 171 |
k,
|
| 172 |
v,
|
|
@@ -204,4 +205,4 @@ def flash_attn_varlen_func(
|
|
| 204 |
sinks,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
return (out, softmax_lse
|
|
|
|
| 86 |
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
| 87 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 88 |
|
| 89 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 90 |
q,
|
| 91 |
k_cache,
|
| 92 |
v_cache,
|
|
|
|
| 123 |
sm_margin,
|
| 124 |
sinks,
|
| 125 |
)
|
| 126 |
+
|
| 127 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
| 128 |
|
| 129 |
|
| 130 |
def flash_attn_varlen_func(
|
|
|
|
| 167 |
)
|
| 168 |
attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
|
| 169 |
|
| 170 |
+
out, softmax_lse, *_ = ops.fwd(
|
| 171 |
q,
|
| 172 |
k,
|
| 173 |
v,
|
|
|
|
| 205 |
sinks,
|
| 206 |
)
|
| 207 |
|
| 208 |
+
return (out, softmax_lse) if return_softmax_lse else out
|
build/torch29-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _sgl_flash_attn3_cuda_1988a8e
|
| 3 |
+
ops = torch.ops._sgl_flash_attn3_cuda_1988a8e
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_sgl_flash_attn3_cuda_1988a8e::{op_name}"
|
build/torch29-cxx11-cu130-x86_64-linux/{_sgl_flash_attn3_cuda_b8d1001.abi3.so → _sgl_flash_attn3_cuda_1988a8e.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 917907296
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c23184d7250226dad7f0d023593f9c45b9db57c9f1daa6e08e216a32f78df8fa
|
| 3 |
size 917907296
|
build/torch29-cxx11-cu130-x86_64-linux/metadata.json
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
-
"python-depends": []
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
{
|
| 2 |
"version": 1,
|
| 3 |
"license": "BSD-3-Clause",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda",
|
| 7 |
+
"archs": [
|
| 8 |
+
"8.0",
|
| 9 |
+
"9.0a"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|