Mayo committed on
perf: FLUX.2 improvements
Browse files
- .cargo/config.toml +1 -1
- Cargo.lock +44 -24
- Cargo.toml +7 -3
- koharu-ml/Cargo.toml +6 -0
- koharu-ml/src/flux2_klein/mod.rs +69 -7
- koharu-ml/src/flux2_klein/transformer.rs +150 -40
- koharu-ml/src/flux2_klein/vae.rs +18 -25
- koharu-runtime/src/cuda.rs +43 -2
- koharu/tauri.windows.conf.json +2 -2
.cargo/config.toml
CHANGED
|
@@ -5,4 +5,4 @@ LLAMA_CPP_TAG = "b8665"
|
|
| 5 |
# CUDA 13.0 requires C++17
|
| 6 |
NVCC_PREPEND_FLAGS = "-std=c++17"
|
| 7 |
# override nvidia-smi compute capability
|
| 8 |
-
CUDA_COMPUTE_CAP = "
|
|
|
|
| 5 |
# CUDA 13.0 requires C++17
|
| 6 |
NVCC_PREPEND_FLAGS = "-std=c++17"
|
| 7 |
# override nvidia-smi compute capability
|
| 8 |
+
CUDA_COMPUTE_CAP = "80"
|
Cargo.lock
CHANGED
|
@@ -828,9 +828,9 @@ dependencies = [
|
|
| 828 |
|
| 829 |
[[package]]
|
| 830 |
name = "blake3"
|
| 831 |
-
version = "1.8.
|
| 832 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 833 |
-
checksum = "
|
| 834 |
dependencies = [
|
| 835 |
"arrayref",
|
| 836 |
"arrayvec",
|
|
@@ -1007,7 +1007,7 @@ dependencies = [
|
|
| 1007 |
[[package]]
|
| 1008 |
name = "candle-core"
|
| 1009 |
version = "0.9.2"
|
| 1010 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1011 |
dependencies = [
|
| 1012 |
"byteorder",
|
| 1013 |
"candle-kernels",
|
|
@@ -1033,10 +1033,29 @@ dependencies = [
|
|
| 1033 |
"zip 7.2.0",
|
| 1034 |
]
|
| 1035 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1036 |
[[package]]
|
| 1037 |
name = "candle-kernels"
|
| 1038 |
version = "0.9.2"
|
| 1039 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1040 |
dependencies = [
|
| 1041 |
"bindgen_cuda",
|
| 1042 |
]
|
|
@@ -1044,7 +1063,7 @@ dependencies = [
|
|
| 1044 |
[[package]]
|
| 1045 |
name = "candle-metal-kernels"
|
| 1046 |
version = "0.9.2"
|
| 1047 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1048 |
dependencies = [
|
| 1049 |
"half",
|
| 1050 |
"objc2",
|
|
@@ -1058,7 +1077,7 @@ dependencies = [
|
|
| 1058 |
[[package]]
|
| 1059 |
name = "candle-nn"
|
| 1060 |
version = "0.9.2"
|
| 1061 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1062 |
dependencies = [
|
| 1063 |
"candle-core",
|
| 1064 |
"candle-metal-kernels",
|
|
@@ -1075,7 +1094,7 @@ dependencies = [
|
|
| 1075 |
[[package]]
|
| 1076 |
name = "candle-transformers"
|
| 1077 |
version = "0.9.2"
|
| 1078 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1079 |
dependencies = [
|
| 1080 |
"byteorder",
|
| 1081 |
"candle-core",
|
|
@@ -1093,7 +1112,7 @@ dependencies = [
|
|
| 1093 |
[[package]]
|
| 1094 |
name = "candle-ug"
|
| 1095 |
version = "0.9.2"
|
| 1096 |
-
source = "git+https://github.com/mayocream/candle?branch=
|
| 1097 |
dependencies = [
|
| 1098 |
"ug",
|
| 1099 |
"ug-cuda",
|
|
@@ -1150,9 +1169,9 @@ dependencies = [
|
|
| 1150 |
|
| 1151 |
[[package]]
|
| 1152 |
name = "cc"
|
| 1153 |
-
version = "1.2.
|
| 1154 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 1155 |
-
checksum = "
|
| 1156 |
dependencies = [
|
| 1157 |
"find-msvc-tools",
|
| 1158 |
"jobserver",
|
|
@@ -1790,9 +1809,9 @@ dependencies = [
|
|
| 1790 |
|
| 1791 |
[[package]]
|
| 1792 |
name = "data-encoding"
|
| 1793 |
-
version = "2.
|
| 1794 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 1795 |
-
checksum = "
|
| 1796 |
|
| 1797 |
[[package]]
|
| 1798 |
name = "debugid"
|
|
@@ -2078,14 +2097,14 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
|
| 2078 |
|
| 2079 |
[[package]]
|
| 2080 |
name = "embed-resource"
|
| 2081 |
-
version = "3.0.
|
| 2082 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 2083 |
-
checksum = "
|
| 2084 |
dependencies = [
|
| 2085 |
"cc",
|
| 2086 |
"memchr",
|
| 2087 |
"rustc_version",
|
| 2088 |
-
"toml
|
| 2089 |
"vswhom",
|
| 2090 |
"winreg 0.55.0",
|
| 2091 |
]
|
|
@@ -4677,6 +4696,7 @@ version = "0.49.0"
|
|
| 4677 |
dependencies = [
|
| 4678 |
"anyhow",
|
| 4679 |
"candle-core",
|
|
|
|
| 4680 |
"candle-nn",
|
| 4681 |
"candle-transformers",
|
| 4682 |
"clap",
|
|
@@ -4866,9 +4886,9 @@ dependencies = [
|
|
| 4866 |
|
| 4867 |
[[package]]
|
| 4868 |
name = "libc"
|
| 4869 |
-
version = "0.2.
|
| 4870 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 4871 |
-
checksum = "
|
| 4872 |
|
| 4873 |
[[package]]
|
| 4874 |
name = "libfuzzer-sys"
|
|
@@ -6080,9 +6100,9 @@ checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec"
|
|
| 6080 |
|
| 6081 |
[[package]]
|
| 6082 |
name = "pastey"
|
| 6083 |
-
version = "0.2.
|
| 6084 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 6085 |
-
checksum = "
|
| 6086 |
|
| 6087 |
[[package]]
|
| 6088 |
name = "pathdiff"
|
|
@@ -7354,7 +7374,7 @@ dependencies = [
|
|
| 7354 |
"http 1.4.0",
|
| 7355 |
"http-body",
|
| 7356 |
"http-body-util",
|
| 7357 |
-
"pastey 0.2.
|
| 7358 |
"pin-project-lite",
|
| 7359 |
"rand 0.10.1",
|
| 7360 |
"rmcp-macros",
|
|
@@ -7462,9 +7482,9 @@ dependencies = [
|
|
| 7462 |
|
| 7463 |
[[package]]
|
| 7464 |
name = "rustls"
|
| 7465 |
-
version = "0.23.
|
| 7466 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 7467 |
-
checksum = "
|
| 7468 |
dependencies = [
|
| 7469 |
"aws-lc-rs",
|
| 7470 |
"log",
|
|
@@ -7490,9 +7510,9 @@ dependencies = [
|
|
| 7490 |
|
| 7491 |
[[package]]
|
| 7492 |
name = "rustls-pki-types"
|
| 7493 |
-
version = "1.14.
|
| 7494 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 7495 |
-
checksum = "
|
| 7496 |
dependencies = [
|
| 7497 |
"web-time",
|
| 7498 |
"zeroize",
|
|
|
|
| 828 |
|
| 829 |
[[package]]
|
| 830 |
name = "blake3"
|
| 831 |
+
version = "1.8.5"
|
| 832 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 833 |
+
checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce"
|
| 834 |
dependencies = [
|
| 835 |
"arrayref",
|
| 836 |
"arrayvec",
|
|
|
|
| 1007 |
[[package]]
|
| 1008 |
name = "candle-core"
|
| 1009 |
version = "0.9.2"
|
| 1010 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1011 |
dependencies = [
|
| 1012 |
"byteorder",
|
| 1013 |
"candle-kernels",
|
|
|
|
| 1033 |
"zip 7.2.0",
|
| 1034 |
]
|
| 1035 |
|
| 1036 |
+
[[package]]
|
| 1037 |
+
name = "candle-flash-attn"
|
| 1038 |
+
version = "0.9.2"
|
| 1039 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1040 |
+
dependencies = [
|
| 1041 |
+
"anyhow",
|
| 1042 |
+
"candle-core",
|
| 1043 |
+
"candle-flash-attn-build",
|
| 1044 |
+
"half",
|
| 1045 |
+
]
|
| 1046 |
+
|
| 1047 |
+
[[package]]
|
| 1048 |
+
name = "candle-flash-attn-build"
|
| 1049 |
+
version = "0.9.2"
|
| 1050 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1051 |
+
dependencies = [
|
| 1052 |
+
"anyhow",
|
| 1053 |
+
]
|
| 1054 |
+
|
| 1055 |
[[package]]
|
| 1056 |
name = "candle-kernels"
|
| 1057 |
version = "0.9.2"
|
| 1058 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1059 |
dependencies = [
|
| 1060 |
"bindgen_cuda",
|
| 1061 |
]
|
|
|
|
| 1063 |
[[package]]
|
| 1064 |
name = "candle-metal-kernels"
|
| 1065 |
version = "0.9.2"
|
| 1066 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1067 |
dependencies = [
|
| 1068 |
"half",
|
| 1069 |
"objc2",
|
|
|
|
| 1077 |
[[package]]
|
| 1078 |
name = "candle-nn"
|
| 1079 |
version = "0.9.2"
|
| 1080 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1081 |
dependencies = [
|
| 1082 |
"candle-core",
|
| 1083 |
"candle-metal-kernels",
|
|
|
|
| 1094 |
[[package]]
|
| 1095 |
name = "candle-transformers"
|
| 1096 |
version = "0.9.2"
|
| 1097 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1098 |
dependencies = [
|
| 1099 |
"byteorder",
|
| 1100 |
"candle-core",
|
|
|
|
| 1112 |
[[package]]
|
| 1113 |
name = "candle-ug"
|
| 1114 |
version = "0.9.2"
|
| 1115 |
+
source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
|
| 1116 |
dependencies = [
|
| 1117 |
"ug",
|
| 1118 |
"ug-cuda",
|
|
|
|
| 1169 |
|
| 1170 |
[[package]]
|
| 1171 |
name = "cc"
|
| 1172 |
+
version = "1.2.61"
|
| 1173 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 1174 |
+
checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d"
|
| 1175 |
dependencies = [
|
| 1176 |
"find-msvc-tools",
|
| 1177 |
"jobserver",
|
|
|
|
| 1809 |
|
| 1810 |
[[package]]
|
| 1811 |
name = "data-encoding"
|
| 1812 |
+
version = "2.11.0"
|
| 1813 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 1814 |
+
checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8"
|
| 1815 |
|
| 1816 |
[[package]]
|
| 1817 |
name = "debugid"
|
|
|
|
| 2097 |
|
| 2098 |
[[package]]
|
| 2099 |
name = "embed-resource"
|
| 2100 |
+
version = "3.0.9"
|
| 2101 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 2102 |
+
checksum = "c31a88c8d26de40ed18fe748c547845aa39de1db3afd958f8cb91579f3644bcb"
|
| 2103 |
dependencies = [
|
| 2104 |
"cc",
|
| 2105 |
"memchr",
|
| 2106 |
"rustc_version",
|
| 2107 |
+
"toml 1.1.2+spec-1.1.0",
|
| 2108 |
"vswhom",
|
| 2109 |
"winreg 0.55.0",
|
| 2110 |
]
|
|
|
|
| 4696 |
dependencies = [
|
| 4697 |
"anyhow",
|
| 4698 |
"candle-core",
|
| 4699 |
+
"candle-flash-attn",
|
| 4700 |
"candle-nn",
|
| 4701 |
"candle-transformers",
|
| 4702 |
"clap",
|
|
|
|
| 4886 |
|
| 4887 |
[[package]]
|
| 4888 |
name = "libc"
|
| 4889 |
+
version = "0.2.186"
|
| 4890 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 4891 |
+
checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
|
| 4892 |
|
| 4893 |
[[package]]
|
| 4894 |
name = "libfuzzer-sys"
|
|
|
|
| 6100 |
|
| 6101 |
[[package]]
|
| 6102 |
name = "pastey"
|
| 6103 |
+
version = "0.2.2"
|
| 6104 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 6105 |
+
checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a"
|
| 6106 |
|
| 6107 |
[[package]]
|
| 6108 |
name = "pathdiff"
|
|
|
|
| 7374 |
"http 1.4.0",
|
| 7375 |
"http-body",
|
| 7376 |
"http-body-util",
|
| 7377 |
+
"pastey 0.2.2",
|
| 7378 |
"pin-project-lite",
|
| 7379 |
"rand 0.10.1",
|
| 7380 |
"rmcp-macros",
|
|
|
|
| 7482 |
|
| 7483 |
[[package]]
|
| 7484 |
name = "rustls"
|
| 7485 |
+
version = "0.23.39"
|
| 7486 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 7487 |
+
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
|
| 7488 |
dependencies = [
|
| 7489 |
"aws-lc-rs",
|
| 7490 |
"log",
|
|
|
|
| 7510 |
|
| 7511 |
[[package]]
|
| 7512 |
name = "rustls-pki-types"
|
| 7513 |
+
version = "1.14.1"
|
| 7514 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 7515 |
+
checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9"
|
| 7516 |
dependencies = [
|
| 7517 |
"web-time",
|
| 7518 |
"zeroize",
|
Cargo.toml
CHANGED
|
@@ -44,6 +44,7 @@ koharu-rpc = { path = "koharu-rpc", default-features = false }
|
|
| 44 |
candle-transformers = "=0.9.2"
|
| 45 |
candle-core = "=0.9.2"
|
| 46 |
candle-nn = "=0.9.2"
|
|
|
|
| 47 |
hf-hub = "0.5"
|
| 48 |
image = "0.25"
|
| 49 |
anyhow = "1.0"
|
|
@@ -102,7 +103,9 @@ cudarc = { version = "0.19.4", features = [
|
|
| 102 |
"cublas",
|
| 103 |
"cublaslt",
|
| 104 |
"curand",
|
|
|
|
| 105 |
"driver",
|
|
|
|
| 106 |
"nvrtc",
|
| 107 |
"f16",
|
| 108 |
"f8",
|
|
@@ -166,9 +169,10 @@ natord = "1.0.9"
|
|
| 166 |
sentry = { version = "0.47", features = ["tracing"] }
|
| 167 |
|
| 168 |
[patch.crates-io]
|
| 169 |
-
candle-transformers = { git = "https://github.com/mayocream/candle", branch = "
|
| 170 |
-
candle-core = { git = "https://github.com/mayocream/candle", branch = "
|
| 171 |
-
candle-nn = { git = "https://github.com/mayocream/candle", branch = "
|
|
|
|
| 172 |
ug = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
|
| 173 |
ug-cuda = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
|
| 174 |
|
|
|
|
| 44 |
candle-transformers = "=0.9.2"
|
| 45 |
candle-core = "=0.9.2"
|
| 46 |
candle-nn = "=0.9.2"
|
| 47 |
+
candle-flash-attn = "=0.9.2"
|
| 48 |
hf-hub = "0.5"
|
| 49 |
image = "0.25"
|
| 50 |
anyhow = "1.0"
|
|
|
|
| 103 |
"cublas",
|
| 104 |
"cublaslt",
|
| 105 |
"curand",
|
| 106 |
+
"cudnn",
|
| 107 |
"driver",
|
| 108 |
+
"dynamic-loading",
|
| 109 |
"nvrtc",
|
| 110 |
"f16",
|
| 111 |
"f8",
|
|
|
|
| 169 |
sentry = { version = "0.47", features = ["tracing"] }
|
| 170 |
|
| 171 |
[patch.crates-io]
|
| 172 |
+
candle-transformers = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
|
| 173 |
+
candle-core = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
|
| 174 |
+
candle-nn = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
|
| 175 |
+
candle-flash-attn = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
|
| 176 |
ug = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
|
| 177 |
ug-cuda = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
|
| 178 |
|
koharu-ml/Cargo.toml
CHANGED
|
@@ -20,6 +20,7 @@ imageproc = { workspace = true }
|
|
| 20 |
candle-core = { workspace = true }
|
| 21 |
candle-transformers = { workspace = true }
|
| 22 |
candle-nn = { workspace = true }
|
|
|
|
| 23 |
tokenizers = { workspace = true }
|
| 24 |
serde = { workspace = true }
|
| 25 |
serde_json = { workspace = true }
|
|
@@ -44,9 +45,14 @@ objc2-foundation = { workspace = true, optional = true }
|
|
| 44 |
[features]
|
| 45 |
cuda = [
|
| 46 |
"candle-core/cuda",
|
|
|
|
| 47 |
"candle-nn/cuda",
|
|
|
|
| 48 |
"candle-transformers/cuda",
|
|
|
|
| 49 |
"cudarc",
|
|
|
|
|
|
|
| 50 |
]
|
| 51 |
metal = [
|
| 52 |
"candle-core/metal",
|
|
|
|
| 20 |
candle-core = { workspace = true }
|
| 21 |
candle-transformers = { workspace = true }
|
| 22 |
candle-nn = { workspace = true }
|
| 23 |
+
candle-flash-attn = { workspace = true, optional = true }
|
| 24 |
tokenizers = { workspace = true }
|
| 25 |
serde = { workspace = true }
|
| 26 |
serde_json = { workspace = true }
|
|
|
|
| 45 |
[features]
|
| 46 |
cuda = [
|
| 47 |
"candle-core/cuda",
|
| 48 |
+
"candle-core/cudnn",
|
| 49 |
"candle-nn/cuda",
|
| 50 |
+
"candle-nn/cudnn",
|
| 51 |
"candle-transformers/cuda",
|
| 52 |
+
"candle-transformers/cudnn",
|
| 53 |
"cudarc",
|
| 54 |
+
"candle-flash-attn",
|
| 55 |
+
"candle-flash-attn/cudnn",
|
| 56 |
]
|
| 57 |
metal = [
|
| 58 |
"candle-core/metal",
|
koharu-ml/src/flux2_klein/mod.rs
CHANGED
|
@@ -191,6 +191,7 @@ impl Flux2Klein {
|
|
| 191 |
return Ok(image.clone());
|
| 192 |
}
|
| 193 |
|
|
|
|
| 194 |
let (latents, packed_h, packed_w, size) = {
|
| 195 |
let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
|
| 196 |
let image_latents = self.encode_image_latents(&rgb)?;
|
|
@@ -226,6 +227,7 @@ impl Flux2Klein {
|
|
| 226 |
)?;
|
| 227 |
}
|
| 228 |
let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
|
|
|
|
| 229 |
|
| 230 |
let mut scheduler =
|
| 231 |
FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
|
|
@@ -236,6 +238,9 @@ impl Flux2Klein {
|
|
| 236 |
let initial_timestep = timesteps[start_index];
|
| 237 |
let mut latents =
|
| 238 |
pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
for step_idx in start_index..timesteps.len() {
|
| 241 |
let timestep = Tensor::from_vec(
|
|
@@ -250,7 +255,6 @@ impl Flux2Klein {
|
|
| 250 |
],
|
| 251 |
1,
|
| 252 |
)?;
|
| 253 |
-
let img_ids = Tensor::cat(&[latent_ids.clone(), condition_ids.clone()], 1)?;
|
| 254 |
let noise_pred = self.transformer.forward(
|
| 255 |
&latent_model_input,
|
| 256 |
&img_ids,
|
|
@@ -258,10 +262,15 @@ impl Flux2Klein {
|
|
| 258 |
&text_ids,
|
| 259 |
&timestep,
|
| 260 |
)?;
|
|
|
|
|
|
|
| 261 |
let noise_pred = noise_pred
|
| 262 |
.narrow(1, 0, latents.dim(1)?)?
|
| 263 |
.to_dtype(DType::F32)?;
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
| 265 |
}
|
| 266 |
|
| 267 |
(latents, packed_h, packed_w, size)
|
|
@@ -322,6 +331,7 @@ impl Flux2Klein {
|
|
| 322 |
reference_image: Option<&DynamicImage>,
|
| 323 |
options: &Flux2InpaintOptions,
|
| 324 |
) -> Result<DynamicImage> {
|
|
|
|
| 325 |
let (latents, packed_h, packed_w, size) = {
|
| 326 |
let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
|
| 327 |
let resized_mask = expand_mask(
|
|
@@ -362,6 +372,7 @@ impl Flux2Klein {
|
|
| 362 |
)?;
|
| 363 |
}
|
| 364 |
let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
|
|
|
|
| 365 |
|
| 366 |
let mut scheduler =
|
| 367 |
FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
|
|
@@ -373,6 +384,8 @@ impl Flux2Klein {
|
|
| 373 |
let initial_timestep = timesteps[start_index];
|
| 374 |
let mut latents =
|
| 375 |
pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
|
|
|
|
|
|
|
| 376 |
|
| 377 |
for step_idx in start_index..timesteps.len() {
|
| 378 |
let timestep = Tensor::from_vec(
|
|
@@ -387,7 +400,6 @@ impl Flux2Klein {
|
|
| 387 |
],
|
| 388 |
1,
|
| 389 |
)?;
|
| 390 |
-
let img_ids = Tensor::cat(&[latent_ids.clone(), condition_ids.clone()], 1)?;
|
| 391 |
let noise_pred = self.transformer.forward(
|
| 392 |
&latent_model_input,
|
| 393 |
&img_ids,
|
|
@@ -395,10 +407,15 @@ impl Flux2Klein {
|
|
| 395 |
&text_ids,
|
| 396 |
&timestep,
|
| 397 |
)?;
|
|
|
|
|
|
|
| 398 |
let noise_pred = noise_pred
|
| 399 |
.narrow(1, 0, latents.dim(1)?)?
|
| 400 |
.to_dtype(DType::F32)?;
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
let init_latents = if step_idx + 1 < timesteps.len() {
|
| 404 |
scheduler.scale_noise(
|
|
@@ -409,9 +426,11 @@ impl Flux2Klein {
|
|
| 409 |
} else {
|
| 410 |
image_latents_packed.clone()
|
| 411 |
};
|
| 412 |
-
let
|
| 413 |
-
latents = (keep_mask.broadcast_mul(&init_latents)?
|
| 414 |
+ latent_mask.broadcast_mul(&latents)?)?;
|
|
|
|
|
|
|
|
|
|
| 415 |
}
|
| 416 |
|
| 417 |
(latents, packed_h, packed_w, size)
|
|
@@ -457,10 +476,53 @@ impl Flux2Klein {
|
|
| 457 |
}
|
| 458 |
}
|
| 459 |
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
DType::F32
|
| 462 |
}
|
| 463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
fn inpaint_crop_bounds(
|
| 465 |
image: &DynamicImage,
|
| 466 |
mask: &DynamicImage,
|
|
|
|
| 191 |
return Ok(image.clone());
|
| 192 |
}
|
| 193 |
|
| 194 |
+
let _cuda_cleanup = CudaTemporaryMemoryCleanup::new(&self.device);
|
| 195 |
let (latents, packed_h, packed_w, size) = {
|
| 196 |
let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
|
| 197 |
let image_latents = self.encode_image_latents(&rgb)?;
|
|
|
|
| 227 |
)?;
|
| 228 |
}
|
| 229 |
let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
|
| 230 |
+
let img_ids = Tensor::cat(&[latent_ids, condition_ids], 1)?;
|
| 231 |
|
| 232 |
let mut scheduler =
|
| 233 |
FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
|
|
|
|
| 238 |
let initial_timestep = timesteps[start_index];
|
| 239 |
let mut latents =
|
| 240 |
pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
|
| 241 |
+
drop(image_latents_packed);
|
| 242 |
+
drop(image_latents);
|
| 243 |
+
drop(noise);
|
| 244 |
|
| 245 |
for step_idx in start_index..timesteps.len() {
|
| 246 |
let timestep = Tensor::from_vec(
|
|
|
|
| 255 |
],
|
| 256 |
1,
|
| 257 |
)?;
|
|
|
|
| 258 |
let noise_pred = self.transformer.forward(
|
| 259 |
&latent_model_input,
|
| 260 |
&img_ids,
|
|
|
|
| 262 |
&text_ids,
|
| 263 |
&timestep,
|
| 264 |
)?;
|
| 265 |
+
drop(latent_model_input);
|
| 266 |
+
drop(timestep);
|
| 267 |
let noise_pred = noise_pred
|
| 268 |
.narrow(1, 0, latents.dim(1)?)?
|
| 269 |
.to_dtype(DType::F32)?;
|
| 270 |
+
let next_latents = scheduler.step(&noise_pred, &latents)?;
|
| 271 |
+
drop(noise_pred);
|
| 272 |
+
let previous_latents = std::mem::replace(&mut latents, next_latents);
|
| 273 |
+
drop(previous_latents);
|
| 274 |
}
|
| 275 |
|
| 276 |
(latents, packed_h, packed_w, size)
|
|
|
|
| 331 |
reference_image: Option<&DynamicImage>,
|
| 332 |
options: &Flux2InpaintOptions,
|
| 333 |
) -> Result<DynamicImage> {
|
| 334 |
+
let _cuda_cleanup = CudaTemporaryMemoryCleanup::new(&self.device);
|
| 335 |
let (latents, packed_h, packed_w, size) = {
|
| 336 |
let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
|
| 337 |
let resized_mask = expand_mask(
|
|
|
|
| 372 |
)?;
|
| 373 |
}
|
| 374 |
let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
|
| 375 |
+
let img_ids = Tensor::cat(&[latent_ids, condition_ids], 1)?;
|
| 376 |
|
| 377 |
let mut scheduler =
|
| 378 |
FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
|
|
|
|
| 384 |
let initial_timestep = timesteps[start_index];
|
| 385 |
let mut latents =
|
| 386 |
pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
|
| 387 |
+
let keep_mask = ((&latent_mask * -1.0)? + 1.0)?;
|
| 388 |
+
drop(noise);
|
| 389 |
|
| 390 |
for step_idx in start_index..timesteps.len() {
|
| 391 |
let timestep = Tensor::from_vec(
|
|
|
|
| 400 |
],
|
| 401 |
1,
|
| 402 |
)?;
|
|
|
|
| 403 |
let noise_pred = self.transformer.forward(
|
| 404 |
&latent_model_input,
|
| 405 |
&img_ids,
|
|
|
|
| 407 |
&text_ids,
|
| 408 |
&timestep,
|
| 409 |
)?;
|
| 410 |
+
drop(latent_model_input);
|
| 411 |
+
drop(timestep);
|
| 412 |
let noise_pred = noise_pred
|
| 413 |
.narrow(1, 0, latents.dim(1)?)?
|
| 414 |
.to_dtype(DType::F32)?;
|
| 415 |
+
let next_latents = scheduler.step(&noise_pred, &latents)?;
|
| 416 |
+
drop(noise_pred);
|
| 417 |
+
let previous_latents = std::mem::replace(&mut latents, next_latents);
|
| 418 |
+
drop(previous_latents);
|
| 419 |
|
| 420 |
let init_latents = if step_idx + 1 < timesteps.len() {
|
| 421 |
scheduler.scale_noise(
|
|
|
|
| 426 |
} else {
|
| 427 |
image_latents_packed.clone()
|
| 428 |
};
|
| 429 |
+
let masked_latents = (keep_mask.broadcast_mul(&init_latents)?
|
|
|
|
| 430 |
+ latent_mask.broadcast_mul(&latents)?)?;
|
| 431 |
+
drop(init_latents);
|
| 432 |
+
let previous_latents = std::mem::replace(&mut latents, masked_latents);
|
| 433 |
+
drop(previous_latents);
|
| 434 |
}
|
| 435 |
|
| 436 |
(latents, packed_h, packed_w, size)
|
|
|
|
| 476 |
}
|
| 477 |
}
|
| 478 |
|
| 479 |
+
struct CudaTemporaryMemoryCleanup<'a> {
|
| 480 |
+
device: &'a Device,
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
impl<'a> CudaTemporaryMemoryCleanup<'a> {
|
| 484 |
+
fn new(device: &'a Device) -> Self {
|
| 485 |
+
Self { device }
|
| 486 |
+
}
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
impl Drop for CudaTemporaryMemoryCleanup<'_> {
|
| 490 |
+
fn drop(&mut self) {
|
| 491 |
+
let _ = release_cuda_temporary_memory(self.device);
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
fn transformer_dtype(device: &Device) -> DType {
|
| 496 |
+
if device.is_cuda() {
|
| 497 |
+
return DType::BF16;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
DType::F32
|
| 501 |
}
|
| 502 |
|
| 503 |
+
fn release_cuda_temporary_memory(device: &Device) -> Result<()> {
|
| 504 |
+
device.synchronize()?;
|
| 505 |
+
|
| 506 |
+
#[cfg(feature = "cuda")]
|
| 507 |
+
if let Ok(cuda_device) = device.as_cuda_device() {
|
| 508 |
+
let stream = cuda_device.cuda_stream();
|
| 509 |
+
let context = stream.context();
|
| 510 |
+
if context.has_async_alloc() {
|
| 511 |
+
context.bind_to_thread()?;
|
| 512 |
+
let pool = unsafe {
|
| 513 |
+
candle_core::cuda::cudarc::driver::result::device::get_mem_pool(
|
| 514 |
+
context.cu_device(),
|
| 515 |
+
)?
|
| 516 |
+
};
|
| 517 |
+
unsafe {
|
| 518 |
+
candle_core::cuda::cudarc::driver::result::mem_pool::trim_to(pool, 0)?;
|
| 519 |
+
}
|
| 520 |
+
}
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
Ok(())
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
fn inpaint_crop_bounds(
|
| 527 |
image: &DynamicImage,
|
| 528 |
mask: &DynamicImage,
|
koharu-ml/src/flux2_klein/transformer.rs
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
use std::path::Path;
|
| 2 |
|
| 3 |
-
use candle_core::{D, DType, IndexOp, Module, Result, Tensor};
|
| 4 |
-
use candle_nn::{LayerNorm, RmsNorm};
|
| 5 |
-
use candle_transformers::quantized_nn::{Linear, linear_b};
|
| 6 |
use candle_transformers::quantized_var_builder::VarBuilder;
|
| 7 |
|
| 8 |
#[derive(Debug, Clone)]
|
|
@@ -32,8 +30,97 @@ impl Default for Flux2TransformerConfig {
|
|
| 32 |
}
|
| 33 |
}
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
fn qlinear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
}
|
| 38 |
|
| 39 |
fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
|
|
@@ -83,6 +170,18 @@ fn apply_rope(xs: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
|
| 83 |
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
| 84 |
let dim = q.dim(D::Minus1)?;
|
| 85 |
let scale = 1.0 / (dim as f64).sqrt();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
if q.device().is_metal() {
|
| 87 |
return candle_nn::ops::sdpa(q, k, v, None, false, scale as f32, 1.0);
|
| 88 |
}
|
|
@@ -107,6 +206,8 @@ fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor>
|
|
| 107 |
let q = apply_rope(q, pe)?.contiguous()?;
|
| 108 |
let k = apply_rope(k, pe)?.contiguous()?;
|
| 109 |
let xs = scaled_dot_product_attention(&q, &k, v)?;
|
|
|
|
|
|
|
| 110 |
xs.transpose(1, 2)?.flatten_from(2)
|
| 111 |
}
|
| 112 |
|
|
@@ -265,6 +366,7 @@ impl SelfAttention {
|
|
| 265 |
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
| 266 |
let q = q.apply(&self.norm.query_norm)?;
|
| 267 |
let k = k.apply(&self.norm.key_norm)?;
|
|
|
|
| 268 |
Ok((q, k, v))
|
| 269 |
}
|
| 270 |
}
|
|
@@ -284,7 +386,9 @@ impl Mlp {
|
|
| 284 |
}
|
| 285 |
|
| 286 |
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 287 |
-
|
|
|
|
|
|
|
| 288 |
}
|
| 289 |
}
|
| 290 |
|
|
@@ -336,8 +440,10 @@ impl DoubleStreamBlock {
|
|
| 336 |
|
| 337 |
let img_modulated = img_mod1.scale_shift(&img.apply(&self.img_norm1)?)?;
|
| 338 |
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
|
|
|
|
| 339 |
let txt_modulated = txt_mod1.scale_shift(&txt.apply(&self.txt_norm1)?)?;
|
| 340 |
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
|
|
|
|
| 341 |
|
| 342 |
let attn = {
|
| 343 |
let q = Tensor::cat(&[&txt_q, &img_q], 2)?;
|
|
@@ -361,44 +467,31 @@ impl DoubleStreamBlock {
|
|
| 361 |
let img_attn = img_attn.apply(&self.img_attn.proj)?;
|
| 362 |
let txt_attn = txt_attn.apply(&self.txt_attn.proj)?;
|
| 363 |
drop(attn);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
drop(img_modulated);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
drop(txt_modulated);
|
| 366 |
-
|
| 367 |
-
let
|
| 368 |
-
drop(img_attn);
|
| 369 |
-
let img_mlp = img_mod2
|
| 370 |
-
.scale_shift(&img.apply(&self.img_norm2)?)?
|
| 371 |
-
.apply_fn(|xs| self.img_mlp.forward(xs))?;
|
| 372 |
-
let img = (img + img_mod2.gate(&img_mlp)?)?;
|
| 373 |
-
drop(img_mlp);
|
| 374 |
-
|
| 375 |
-
let txt = (txt + txt_mod1.gate(&txt_attn)?)?;
|
| 376 |
-
drop(txt_attn);
|
| 377 |
-
let txt_mlp = txt_mod2
|
| 378 |
-
.scale_shift(&txt.apply(&self.txt_norm2)?)?
|
| 379 |
-
.apply_fn(|xs| self.txt_mlp.forward(xs))?;
|
| 380 |
-
let txt = (txt + txt_mod2.gate(&txt_mlp)?)?;
|
| 381 |
-
drop(txt_mlp);
|
| 382 |
|
| 383 |
Ok((img, txt))
|
| 384 |
}
|
| 385 |
}
|
| 386 |
|
| 387 |
-
trait ApplyFn {
|
| 388 |
-
fn apply_fn<F>(&self, f: F) -> Result<Tensor>
|
| 389 |
-
where
|
| 390 |
-
F: FnOnce(&Tensor) -> Result<Tensor>;
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
impl ApplyFn for Tensor {
|
| 394 |
-
fn apply_fn<F>(&self, f: F) -> Result<Tensor>
|
| 395 |
-
where
|
| 396 |
-
F: FnOnce(&Tensor) -> Result<Tensor>,
|
| 397 |
-
{
|
| 398 |
-
f(self)
|
| 399 |
-
}
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
#[derive(Debug, Clone)]
|
| 403 |
struct SingleStreamBlock {
|
| 404 |
linear1: Linear,
|
|
@@ -432,8 +525,11 @@ impl SingleStreamBlock {
|
|
| 432 |
|
| 433 |
fn forward(&self, xs: &Tensor, mods: &[ModulationOut], pe: &Tensor) -> Result<Tensor> {
|
| 434 |
let mod_ = &mods[0];
|
| 435 |
-
let
|
|
|
|
|
|
|
| 436 |
let qkv_mlp = x_mod.apply(&self.linear1)?;
|
|
|
|
| 437 |
let qkv = qkv_mlp.narrow(D::Minus1, 0, 3 * self.hidden_size)?;
|
| 438 |
let (b, len, _) = qkv.dims3()?;
|
| 439 |
let qkv = qkv.reshape((b, len, 3, self.num_heads, ()))?;
|
|
@@ -441,6 +537,8 @@ impl SingleStreamBlock {
|
|
| 441 |
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
| 442 |
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
| 443 |
let mlp = qkv_mlp.narrow(D::Minus1, 3 * self.hidden_size, self.mlp_size * 2)?;
|
|
|
|
|
|
|
| 444 |
let q = q.apply(&self.norm.query_norm)?;
|
| 445 |
let k = k.apply(&self.norm.key_norm)?;
|
| 446 |
let attn = attention(&q, &k, &v, pe)?;
|
|
@@ -448,10 +546,13 @@ impl SingleStreamBlock {
|
|
| 448 |
drop(k);
|
| 449 |
drop(v);
|
| 450 |
let mlp = swiglu(&mlp)?;
|
| 451 |
-
let output = Tensor::cat(&[&attn, &mlp], D::Minus1)?
|
| 452 |
drop(attn);
|
| 453 |
drop(mlp);
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
| 455 |
}
|
| 456 |
}
|
| 457 |
|
|
@@ -585,6 +686,7 @@ impl Flux2Transformer {
|
|
| 585 |
let dtype = img.dtype();
|
| 586 |
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
|
| 587 |
let pe = self.pe_embedder.forward(&ids)?;
|
|
|
|
| 588 |
let mut img = img.apply(&self.img_in)?;
|
| 589 |
let mut txt = txt.apply(&self.txt_in)?;
|
| 590 |
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
|
|
@@ -595,14 +697,22 @@ impl Flux2Transformer {
|
|
| 595 |
for block in &self.double_blocks {
|
| 596 |
(img, txt) = block.forward(&img, &txt, &ds_img_mods, &ds_txt_mods, &pe)?;
|
| 597 |
}
|
|
|
|
|
|
|
| 598 |
let txt_len = txt.dim(1)?;
|
| 599 |
let img_len = img.dim(1)?;
|
| 600 |
let mut xs = Tensor::cat(&[&txt, &img], 1)?;
|
|
|
|
|
|
|
| 601 |
for block in &self.single_blocks {
|
| 602 |
xs = block.forward(&xs, &ss_mods, &pe)?;
|
| 603 |
}
|
|
|
|
|
|
|
| 604 |
let img = xs.narrow(1, txt_len, img_len)?;
|
| 605 |
-
self.final_layer.forward(&img, &vec_)
|
|
|
|
|
|
|
| 606 |
}
|
| 607 |
|
| 608 |
pub fn in_channels(&self) -> usize {
|
|
|
|
| 1 |
use std::path::Path;
|
| 2 |
|
| 3 |
+
use candle_core::{D, DType, IndexOp, Module, Result, Tensor, quantized::QMatMul};
|
|
|
|
|
|
|
| 4 |
use candle_transformers::quantized_var_builder::VarBuilder;
|
| 5 |
|
| 6 |
#[derive(Debug, Clone)]
|
|
|
|
| 30 |
}
|
| 31 |
}
|
| 32 |
|
| 33 |
+
#[derive(Debug, Clone)]
|
| 34 |
+
struct Linear {
|
| 35 |
+
weight: QMatMul,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
impl Module for Linear {
|
| 39 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 40 |
+
let dtype = xs.dtype();
|
| 41 |
+
let xs = if should_promote_for_cuda(xs) {
|
| 42 |
+
xs.to_dtype(DType::F32)?
|
| 43 |
+
} else {
|
| 44 |
+
xs.clone()
|
| 45 |
+
};
|
| 46 |
+
let ys = xs.apply(&self.weight)?;
|
| 47 |
+
if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
|
| 48 |
+
ys.to_dtype(dtype)
|
| 49 |
+
} else {
|
| 50 |
+
Ok(ys)
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
fn qlinear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
| 56 |
+
let weight = vb.get((out_dim, in_dim), "weight")?;
|
| 57 |
+
Ok(Linear {
|
| 58 |
+
weight: QMatMul::from_arc(weight)?,
|
| 59 |
+
})
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
#[derive(Debug, Clone)]
|
| 63 |
+
struct LayerNorm {
|
| 64 |
+
inner: candle_nn::LayerNorm,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
impl LayerNorm {
|
| 68 |
+
fn new_no_bias(weight: Tensor, eps: f64) -> Self {
|
| 69 |
+
Self {
|
| 70 |
+
inner: candle_nn::LayerNorm::new_no_bias(weight, eps),
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
impl Module for LayerNorm {
|
| 76 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 77 |
+
let dtype = xs.dtype();
|
| 78 |
+
let xs = if should_promote_for_cuda(xs) {
|
| 79 |
+
xs.to_dtype(DType::F32)?
|
| 80 |
+
} else {
|
| 81 |
+
xs.clone()
|
| 82 |
+
};
|
| 83 |
+
let ys = xs.apply(&self.inner)?;
|
| 84 |
+
if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
|
| 85 |
+
ys.to_dtype(dtype)
|
| 86 |
+
} else {
|
| 87 |
+
Ok(ys)
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
#[derive(Debug, Clone)]
|
| 93 |
+
struct RmsNorm {
|
| 94 |
+
inner: candle_nn::RmsNorm,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
impl RmsNorm {
|
| 98 |
+
fn new(weight: Tensor, eps: f64) -> Self {
|
| 99 |
+
Self {
|
| 100 |
+
inner: candle_nn::RmsNorm::new(weight, eps),
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
impl Module for RmsNorm {
|
| 106 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 107 |
+
let dtype = xs.dtype();
|
| 108 |
+
let xs = if should_promote_for_cuda(xs) {
|
| 109 |
+
xs.to_dtype(DType::F32)?
|
| 110 |
+
} else {
|
| 111 |
+
xs.clone()
|
| 112 |
+
};
|
| 113 |
+
let ys = xs.apply(&self.inner)?;
|
| 114 |
+
if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
|
| 115 |
+
ys.to_dtype(dtype)
|
| 116 |
+
} else {
|
| 117 |
+
Ok(ys)
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
fn should_promote_for_cuda(xs: &Tensor) -> bool {
|
| 123 |
+
xs.device().is_cuda() && matches!(xs.dtype(), DType::BF16 | DType::F16)
|
| 124 |
}
|
| 125 |
|
| 126 |
fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
|
|
|
|
| 170 |
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
| 171 |
let dim = q.dim(D::Minus1)?;
|
| 172 |
let scale = 1.0 / (dim as f64).sqrt();
|
| 173 |
+
#[cfg(feature = "cuda")]
|
| 174 |
+
if q.device().is_cuda() {
|
| 175 |
+
let q = q.transpose(1, 2)?.contiguous()?;
|
| 176 |
+
let k = k.transpose(1, 2)?.contiguous()?;
|
| 177 |
+
let v = v.transpose(1, 2)?.contiguous()?;
|
| 178 |
+
let xs = candle_flash_attn::flash_attn(&q, &k, &v, scale as f32, false)?;
|
| 179 |
+
drop(q);
|
| 180 |
+
drop(k);
|
| 181 |
+
drop(v);
|
| 182 |
+
return xs.transpose(1, 2);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
if q.device().is_metal() {
|
| 186 |
return candle_nn::ops::sdpa(q, k, v, None, false, scale as f32, 1.0);
|
| 187 |
}
|
|
|
|
| 206 |
let q = apply_rope(q, pe)?.contiguous()?;
|
| 207 |
let k = apply_rope(k, pe)?.contiguous()?;
|
| 208 |
let xs = scaled_dot_product_attention(&q, &k, v)?;
|
| 209 |
+
drop(q);
|
| 210 |
+
drop(k);
|
| 211 |
xs.transpose(1, 2)?.flatten_from(2)
|
| 212 |
}
|
| 213 |
|
|
|
|
| 366 |
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
| 367 |
let q = q.apply(&self.norm.query_norm)?;
|
| 368 |
let k = k.apply(&self.norm.key_norm)?;
|
| 369 |
+
drop(qkv);
|
| 370 |
Ok((q, k, v))
|
| 371 |
}
|
| 372 |
}
|
|
|
|
| 386 |
}
|
| 387 |
|
| 388 |
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 389 |
+
let xs = xs.apply(&self.lin1)?;
|
| 390 |
+
let xs = swiglu(&xs)?;
|
| 391 |
+
xs.apply(&self.lin2)
|
| 392 |
}
|
| 393 |
}
|
| 394 |
|
|
|
|
| 440 |
|
| 441 |
let img_modulated = img_mod1.scale_shift(&img.apply(&self.img_norm1)?)?;
|
| 442 |
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
|
| 443 |
+
drop(img_modulated);
|
| 444 |
let txt_modulated = txt_mod1.scale_shift(&txt.apply(&self.txt_norm1)?)?;
|
| 445 |
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
|
| 446 |
+
drop(txt_modulated);
|
| 447 |
|
| 448 |
let attn = {
|
| 449 |
let q = Tensor::cat(&[&txt_q, &img_q], 2)?;
|
|
|
|
| 467 |
let img_attn = img_attn.apply(&self.img_attn.proj)?;
|
| 468 |
let txt_attn = txt_attn.apply(&self.txt_attn.proj)?;
|
| 469 |
drop(attn);
|
| 470 |
+
|
| 471 |
+
let img_attn = img_mod1.gate(&img_attn)?;
|
| 472 |
+
let img = (img + img_attn)?;
|
| 473 |
+
let img_normed = img.apply(&self.img_norm2)?;
|
| 474 |
+
let img_modulated = img_mod2.scale_shift(&img_normed)?;
|
| 475 |
+
drop(img_normed);
|
| 476 |
+
let img_mlp = self.img_mlp.forward(&img_modulated)?;
|
| 477 |
drop(img_modulated);
|
| 478 |
+
let img_mlp = img_mod2.gate(&img_mlp)?;
|
| 479 |
+
let img = (img + img_mlp)?;
|
| 480 |
+
|
| 481 |
+
let txt_attn = txt_mod1.gate(&txt_attn)?;
|
| 482 |
+
let txt = (txt + txt_attn)?;
|
| 483 |
+
let txt_normed = txt.apply(&self.txt_norm2)?;
|
| 484 |
+
let txt_modulated = txt_mod2.scale_shift(&txt_normed)?;
|
| 485 |
+
drop(txt_normed);
|
| 486 |
+
let txt_mlp = self.txt_mlp.forward(&txt_modulated)?;
|
| 487 |
drop(txt_modulated);
|
| 488 |
+
let txt_mlp = txt_mod2.gate(&txt_mlp)?;
|
| 489 |
+
let txt = (txt + txt_mlp)?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
Ok((img, txt))
|
| 492 |
}
|
| 493 |
}
|
| 494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
#[derive(Debug, Clone)]
|
| 496 |
struct SingleStreamBlock {
|
| 497 |
linear1: Linear,
|
|
|
|
| 525 |
|
| 526 |
fn forward(&self, xs: &Tensor, mods: &[ModulationOut], pe: &Tensor) -> Result<Tensor> {
|
| 527 |
let mod_ = &mods[0];
|
| 528 |
+
let x_normed = xs.apply(&self.pre_norm)?;
|
| 529 |
+
let x_mod = mod_.scale_shift(&x_normed)?;
|
| 530 |
+
drop(x_normed);
|
| 531 |
let qkv_mlp = x_mod.apply(&self.linear1)?;
|
| 532 |
+
drop(x_mod);
|
| 533 |
let qkv = qkv_mlp.narrow(D::Minus1, 0, 3 * self.hidden_size)?;
|
| 534 |
let (b, len, _) = qkv.dims3()?;
|
| 535 |
let qkv = qkv.reshape((b, len, 3, self.num_heads, ()))?;
|
|
|
|
| 537 |
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
| 538 |
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
| 539 |
let mlp = qkv_mlp.narrow(D::Minus1, 3 * self.hidden_size, self.mlp_size * 2)?;
|
| 540 |
+
drop(qkv_mlp);
|
| 541 |
+
drop(qkv);
|
| 542 |
let q = q.apply(&self.norm.query_norm)?;
|
| 543 |
let k = k.apply(&self.norm.key_norm)?;
|
| 544 |
let attn = attention(&q, &k, &v, pe)?;
|
|
|
|
| 546 |
drop(k);
|
| 547 |
drop(v);
|
| 548 |
let mlp = swiglu(&mlp)?;
|
| 549 |
+
let output = Tensor::cat(&[&attn, &mlp], D::Minus1)?;
|
| 550 |
drop(attn);
|
| 551 |
drop(mlp);
|
| 552 |
+
let output = output.apply(&self.linear2)?;
|
| 553 |
+
let gated = mod_.gate(&output)?;
|
| 554 |
+
drop(output);
|
| 555 |
+
xs + gated
|
| 556 |
}
|
| 557 |
}
|
| 558 |
|
|
|
|
| 686 |
let dtype = img.dtype();
|
| 687 |
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
|
| 688 |
let pe = self.pe_embedder.forward(&ids)?;
|
| 689 |
+
drop(ids);
|
| 690 |
let mut img = img.apply(&self.img_in)?;
|
| 691 |
let mut txt = txt.apply(&self.txt_in)?;
|
| 692 |
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
|
|
|
|
| 697 |
for block in &self.double_blocks {
|
| 698 |
(img, txt) = block.forward(&img, &txt, &ds_img_mods, &ds_txt_mods, &pe)?;
|
| 699 |
}
|
| 700 |
+
drop(ds_img_mods);
|
| 701 |
+
drop(ds_txt_mods);
|
| 702 |
let txt_len = txt.dim(1)?;
|
| 703 |
let img_len = img.dim(1)?;
|
| 704 |
let mut xs = Tensor::cat(&[&txt, &img], 1)?;
|
| 705 |
+
drop(txt);
|
| 706 |
+
drop(img);
|
| 707 |
for block in &self.single_blocks {
|
| 708 |
xs = block.forward(&xs, &ss_mods, &pe)?;
|
| 709 |
}
|
| 710 |
+
drop(ss_mods);
|
| 711 |
+
drop(pe);
|
| 712 |
let img = xs.narrow(1, txt_len, img_len)?;
|
| 713 |
+
let xs = self.final_layer.forward(&img, &vec_)?;
|
| 714 |
+
drop(img);
|
| 715 |
+
Ok(xs)
|
| 716 |
}
|
| 717 |
|
| 718 |
pub fn in_channels(&self) -> usize {
|
koharu-ml/src/flux2_klein/vae.rs
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
use candle_core::{D, Module, Result, Tensor};
|
| 2 |
use candle_nn::{Conv2d, Conv2dConfig, GroupNorm, VarBuilder, conv2d, group_norm};
|
| 3 |
|
| 4 |
use super::latents::{patchify_latents, unpatchify_latents};
|
|
@@ -30,6 +30,15 @@ impl Default for Flux2VaeConfig {
|
|
| 30 |
}
|
| 31 |
}
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
| 34 |
let dim = q.dim(D::Minus1)?;
|
| 35 |
let scale = 1.0 / (dim as f64).sqrt();
|
|
@@ -113,10 +122,7 @@ impl ResnetBlock2D {
|
|
| 113 |
num_groups: usize,
|
| 114 |
vb: VarBuilder,
|
| 115 |
) -> Result<Self> {
|
| 116 |
-
let conv_cfg =
|
| 117 |
-
padding: 1,
|
| 118 |
-
..Default::default()
|
| 119 |
-
};
|
| 120 |
let norm1 = group_norm(num_groups, in_channels, 1e-6, vb.pp("norm1"))?;
|
| 121 |
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vb.pp("conv1"))?;
|
| 122 |
let norm2 = group_norm(num_groups, out_channels, 1e-6, vb.pp("norm2"))?;
|
|
@@ -126,7 +132,7 @@ impl ResnetBlock2D {
|
|
| 126 |
in_channels,
|
| 127 |
out_channels,
|
| 128 |
1,
|
| 129 |
-
|
| 130 |
vb.pp("conv_shortcut"),
|
| 131 |
)?)
|
| 132 |
} else {
|
|
@@ -165,11 +171,7 @@ struct Downsample2D {
|
|
| 165 |
|
| 166 |
impl Downsample2D {
|
| 167 |
fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
| 168 |
-
let conv_cfg =
|
| 169 |
-
stride: 2,
|
| 170 |
-
padding: 0,
|
| 171 |
-
..Default::default()
|
| 172 |
-
};
|
| 173 |
let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
|
| 174 |
Ok(Self { conv })
|
| 175 |
}
|
|
@@ -243,10 +245,7 @@ struct Upsample2D {
|
|
| 243 |
|
| 244 |
impl Upsample2D {
|
| 245 |
fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
| 246 |
-
let conv_cfg =
|
| 247 |
-
padding: 1,
|
| 248 |
-
..Default::default()
|
| 249 |
-
};
|
| 250 |
let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
|
| 251 |
Ok(Self { conv })
|
| 252 |
}
|
|
@@ -342,10 +341,7 @@ struct Encoder {
|
|
| 342 |
|
| 343 |
impl Encoder {
|
| 344 |
fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
|
| 345 |
-
let conv_cfg =
|
| 346 |
-
padding: 1,
|
| 347 |
-
..Default::default()
|
| 348 |
-
};
|
| 349 |
let conv_in = conv2d(
|
| 350 |
cfg.in_channels,
|
| 351 |
cfg.block_out_channels[0],
|
|
@@ -419,10 +415,7 @@ struct Decoder {
|
|
| 419 |
|
| 420 |
impl Decoder {
|
| 421 |
fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
|
| 422 |
-
let conv_cfg =
|
| 423 |
-
padding: 1,
|
| 424 |
-
..Default::default()
|
| 425 |
-
};
|
| 426 |
let mid_channels = *cfg.decoder_block_out_channels.last().unwrap();
|
| 427 |
let conv_in = conv2d(
|
| 428 |
cfg.latent_channels,
|
|
@@ -512,14 +505,14 @@ impl Flux2Vae {
|
|
| 512 |
2 * cfg.latent_channels,
|
| 513 |
2 * cfg.latent_channels,
|
| 514 |
1,
|
| 515 |
-
|
| 516 |
vb.pp("quant_conv"),
|
| 517 |
)?;
|
| 518 |
let post_quant_conv = conv2d(
|
| 519 |
cfg.latent_channels,
|
| 520 |
cfg.latent_channels,
|
| 521 |
1,
|
| 522 |
-
|
| 523 |
vb.pp("post_quant_conv"),
|
| 524 |
)?;
|
| 525 |
let bn_running_mean = vb.get(4 * cfg.latent_channels, "bn.running_mean")?;
|
|
|
|
| 1 |
+
use candle_core::{D, Module, Result, Tensor, conv::CudnnFwdAlgo};
|
| 2 |
use candle_nn::{Conv2d, Conv2dConfig, GroupNorm, VarBuilder, conv2d, group_norm};
|
| 3 |
|
| 4 |
use super::latents::{patchify_latents, unpatchify_latents};
|
|
|
|
| 30 |
}
|
| 31 |
}
|
| 32 |
|
| 33 |
+
fn vae_conv_config(padding: usize, stride: usize) -> Conv2dConfig {
|
| 34 |
+
Conv2dConfig {
|
| 35 |
+
padding,
|
| 36 |
+
stride,
|
| 37 |
+
cudnn_fwd_algo: Some(CudnnFwdAlgo::ImplicitGemm),
|
| 38 |
+
..Default::default()
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
| 43 |
let dim = q.dim(D::Minus1)?;
|
| 44 |
let scale = 1.0 / (dim as f64).sqrt();
|
|
|
|
| 122 |
num_groups: usize,
|
| 123 |
vb: VarBuilder,
|
| 124 |
) -> Result<Self> {
|
| 125 |
+
let conv_cfg = vae_conv_config(1, 1);
|
|
|
|
|
|
|
|
|
|
| 126 |
let norm1 = group_norm(num_groups, in_channels, 1e-6, vb.pp("norm1"))?;
|
| 127 |
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vb.pp("conv1"))?;
|
| 128 |
let norm2 = group_norm(num_groups, out_channels, 1e-6, vb.pp("norm2"))?;
|
|
|
|
| 132 |
in_channels,
|
| 133 |
out_channels,
|
| 134 |
1,
|
| 135 |
+
vae_conv_config(0, 1),
|
| 136 |
vb.pp("conv_shortcut"),
|
| 137 |
)?)
|
| 138 |
} else {
|
|
|
|
| 171 |
|
| 172 |
impl Downsample2D {
|
| 173 |
fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
| 174 |
+
let conv_cfg = vae_conv_config(0, 2);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
|
| 176 |
Ok(Self { conv })
|
| 177 |
}
|
|
|
|
| 245 |
|
| 246 |
impl Upsample2D {
|
| 247 |
fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
| 248 |
+
let conv_cfg = vae_conv_config(1, 1);
|
|
|
|
|
|
|
|
|
|
| 249 |
let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
|
| 250 |
Ok(Self { conv })
|
| 251 |
}
|
|
|
|
| 341 |
|
| 342 |
impl Encoder {
|
| 343 |
fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
|
| 344 |
+
let conv_cfg = vae_conv_config(1, 1);
|
|
|
|
|
|
|
|
|
|
| 345 |
let conv_in = conv2d(
|
| 346 |
cfg.in_channels,
|
| 347 |
cfg.block_out_channels[0],
|
|
|
|
| 415 |
|
| 416 |
impl Decoder {
|
| 417 |
fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
|
| 418 |
+
let conv_cfg = vae_conv_config(1, 1);
|
|
|
|
|
|
|
|
|
|
| 419 |
let mid_channels = *cfg.decoder_block_out_channels.last().unwrap();
|
| 420 |
let conv_in = conv2d(
|
| 421 |
cfg.latent_channels,
|
|
|
|
| 505 |
2 * cfg.latent_channels,
|
| 506 |
2 * cfg.latent_channels,
|
| 507 |
1,
|
| 508 |
+
vae_conv_config(0, 1),
|
| 509 |
vb.pp("quant_conv"),
|
| 510 |
)?;
|
| 511 |
let post_quant_conv = conv2d(
|
| 512 |
cfg.latent_channels,
|
| 513 |
cfg.latent_channels,
|
| 514 |
1,
|
| 515 |
+
vae_conv_config(0, 1),
|
| 516 |
vb.pp("post_quant_conv"),
|
| 517 |
)?;
|
| 518 |
let bn_running_mean = vb.get(4 * cfg.latent_channels, "bn.running_mean")?;
|
koharu-runtime/src/cuda.rs
CHANGED
|
@@ -11,6 +11,7 @@ use crate::loader::{add_runtime_search_path, preload_library};
|
|
| 11 |
const CUDA_SUCCESS: i32 = 0;
|
| 12 |
const CUDA_13_0_DRIVER_VERSION: i32 = 13000;
|
| 13 |
const CUDA_13_1_DRIVER_VERSION: i32 = 13010;
|
|
|
|
| 14 |
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: i32 = 75;
|
| 15 |
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: i32 = 76;
|
| 16 |
const MIN_COMPUTE_CAPABILITY: (i32, i32) = (7, 5); // Turing (RTX 20xx) and above
|
|
@@ -64,6 +65,31 @@ const WHEELS: &[WheelSpec] = &[
|
|
| 64 |
windows_dylibs: &["curand64_10.dll"],
|
| 65 |
linux_dylibs: &["libcurand.so.10"],
|
| 66 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
];
|
| 68 |
|
| 69 |
impl CudaDriverVersion {
|
|
@@ -379,9 +405,10 @@ impl WheelSpec {
|
|
| 379 |
fn source_id() -> Result<String> {
|
| 380 |
let packages = WHEELS.iter().map(|wheel| wheel.package).collect::<Vec<_>>();
|
| 381 |
Ok(format!(
|
| 382 |
-
"cuda;platform={};wheels={}",
|
| 383 |
platform_tags()?.join(","),
|
| 384 |
-
packages.join(",")
|
|
|
|
| 385 |
))
|
| 386 |
}
|
| 387 |
|
|
@@ -458,6 +485,20 @@ mod tests {
|
|
| 458 |
}
|
| 459 |
}
|
| 460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
#[test]
|
| 462 |
fn parses_major_minor_from_driver_version() {
|
| 463 |
let version = CudaDriverVersion::from_raw(13010);
|
|
|
|
| 11 |
const CUDA_SUCCESS: i32 = 0;
|
| 12 |
const CUDA_13_0_DRIVER_VERSION: i32 = 13000;
|
| 13 |
const CUDA_13_1_DRIVER_VERSION: i32 = 13010;
|
| 14 |
+
const CUDA_EXTRACT_REVISION: u32 = 2;
|
| 15 |
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: i32 = 75;
|
| 16 |
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: i32 = 76;
|
| 17 |
const MIN_COMPUTE_CAPABILITY: (i32, i32) = (7, 5); // Turing (RTX 20xx) and above
|
|
|
|
| 65 |
windows_dylibs: &["curand64_10.dll"],
|
| 66 |
linux_dylibs: &["libcurand.so.10"],
|
| 67 |
},
|
| 68 |
+
WheelSpec {
|
| 69 |
+
package: "nvidia-cudnn-cu13/9.21.0.82",
|
| 70 |
+
windows_dylibs: &[
|
| 71 |
+
"cudnn64_9.dll",
|
| 72 |
+
"cudnn_adv64_9.dll",
|
| 73 |
+
"cudnn_cnn64_9.dll",
|
| 74 |
+
"cudnn_engines_precompiled64_9.dll",
|
| 75 |
+
"cudnn_engines_runtime_compiled64_9.dll",
|
| 76 |
+
"cudnn_engines_tensor_ir64_9.dll",
|
| 77 |
+
"cudnn_graph64_9.dll",
|
| 78 |
+
"cudnn_heuristic64_9.dll",
|
| 79 |
+
"cudnn_ops64_9.dll",
|
| 80 |
+
],
|
| 81 |
+
linux_dylibs: &[
|
| 82 |
+
"libcudnn.so.9",
|
| 83 |
+
"libcudnn_adv.so.9",
|
| 84 |
+
"libcudnn_cnn.so.9",
|
| 85 |
+
"libcudnn_engines_precompiled.so.9",
|
| 86 |
+
"libcudnn_engines_runtime_compiled.so.9",
|
| 87 |
+
"libcudnn_engines_tensor_ir.so.9",
|
| 88 |
+
"libcudnn_graph.so.9",
|
| 89 |
+
"libcudnn_heuristic.so.9",
|
| 90 |
+
"libcudnn_ops.so.9",
|
| 91 |
+
],
|
| 92 |
+
},
|
| 93 |
];
|
| 94 |
|
| 95 |
impl CudaDriverVersion {
|
|
|
|
| 405 |
fn source_id() -> Result<String> {
|
| 406 |
let packages = WHEELS.iter().map(|wheel| wheel.package).collect::<Vec<_>>();
|
| 407 |
Ok(format!(
|
| 408 |
+
"cuda;platform={};wheels={};extract={}",
|
| 409 |
platform_tags()?.join(","),
|
| 410 |
+
packages.join(","),
|
| 411 |
+
CUDA_EXTRACT_REVISION
|
| 412 |
))
|
| 413 |
}
|
| 414 |
|
|
|
|
| 485 |
}
|
| 486 |
}
|
| 487 |
|
| 488 |
+
#[test]
|
| 489 |
+
fn cuda_runtime_includes_cudnn() {
|
| 490 |
+
let wheel = WHEELS
|
| 491 |
+
.iter()
|
| 492 |
+
.find(|wheel| wheel.package.starts_with("nvidia-cudnn-cu13/"))
|
| 493 |
+
.expect("missing cuDNN runtime wheel");
|
| 494 |
+
|
| 495 |
+
#[cfg(target_os = "windows")]
|
| 496 |
+
assert!(wheel.dylibs().contains(&"cudnn64_9.dll"));
|
| 497 |
+
|
| 498 |
+
#[cfg(target_os = "linux")]
|
| 499 |
+
assert!(wheel.dylibs().contains(&"libcudnn.so.9"));
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
#[test]
|
| 503 |
fn parses_major_minor_from_driver_version() {
|
| 504 |
let version = CudaDriverVersion::from_raw(13010);
|
koharu/tauri.windows.conf.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
{
|
| 2 |
-
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
|
| 3 |
"identifier": "Koharu",
|
| 4 |
"build": {
|
| 5 |
"features": [
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
|
| 3 |
"identifier": "Koharu",
|
| 4 |
"build": {
|
| 5 |
"features": [
|