Mayo commited on
Commit
aab9632
·
unverified ·
1 Parent(s): 504f8c2

perf: FLUX.2 improvements

Browse files
.cargo/config.toml CHANGED
@@ -5,4 +5,4 @@ LLAMA_CPP_TAG = "b8665"
5
  # CUDA 13.0 requires C++17
6
  NVCC_PREPEND_FLAGS = "-std=c++17"
7
  # override nvidia-smi compute capability
8
- CUDA_COMPUTE_CAP = "75"
 
5
  # CUDA 13.0 requires C++17
6
  NVCC_PREPEND_FLAGS = "-std=c++17"
7
  # override nvidia-smi compute capability
8
+ CUDA_COMPUTE_CAP = "80"
Cargo.lock CHANGED
@@ -828,9 +828,9 @@ dependencies = [
828
 
829
  [[package]]
830
  name = "blake3"
831
- version = "1.8.4"
832
  source = "registry+https://github.com/rust-lang/crates.io-index"
833
- checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e"
834
  dependencies = [
835
  "arrayref",
836
  "arrayvec",
@@ -1007,7 +1007,7 @@ dependencies = [
1007
  [[package]]
1008
  name = "candle-core"
1009
  version = "0.9.2"
1010
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1011
  dependencies = [
1012
  "byteorder",
1013
  "candle-kernels",
@@ -1033,10 +1033,29 @@ dependencies = [
1033
  "zip 7.2.0",
1034
  ]
1035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1036
  [[package]]
1037
  name = "candle-kernels"
1038
  version = "0.9.2"
1039
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1040
  dependencies = [
1041
  "bindgen_cuda",
1042
  ]
@@ -1044,7 +1063,7 @@ dependencies = [
1044
  [[package]]
1045
  name = "candle-metal-kernels"
1046
  version = "0.9.2"
1047
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1048
  dependencies = [
1049
  "half",
1050
  "objc2",
@@ -1058,7 +1077,7 @@ dependencies = [
1058
  [[package]]
1059
  name = "candle-nn"
1060
  version = "0.9.2"
1061
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1062
  dependencies = [
1063
  "candle-core",
1064
  "candle-metal-kernels",
@@ -1075,7 +1094,7 @@ dependencies = [
1075
  [[package]]
1076
  name = "candle-transformers"
1077
  version = "0.9.2"
1078
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1079
  dependencies = [
1080
  "byteorder",
1081
  "candle-core",
@@ -1093,7 +1112,7 @@ dependencies = [
1093
  [[package]]
1094
  name = "candle-ug"
1095
  version = "0.9.2"
1096
- source = "git+https://github.com/mayocream/candle?branch=cuda-dynamic-loading#38be780754e954d88b63bbe1ef7e4098bbaa4c02"
1097
  dependencies = [
1098
  "ug",
1099
  "ug-cuda",
@@ -1150,9 +1169,9 @@ dependencies = [
1150
 
1151
  [[package]]
1152
  name = "cc"
1153
- version = "1.2.60"
1154
  source = "registry+https://github.com/rust-lang/crates.io-index"
1155
- checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20"
1156
  dependencies = [
1157
  "find-msvc-tools",
1158
  "jobserver",
@@ -1790,9 +1809,9 @@ dependencies = [
1790
 
1791
  [[package]]
1792
  name = "data-encoding"
1793
- version = "2.10.0"
1794
  source = "registry+https://github.com/rust-lang/crates.io-index"
1795
- checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
1796
 
1797
  [[package]]
1798
  name = "debugid"
@@ -2078,14 +2097,14 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
2078
 
2079
  [[package]]
2080
  name = "embed-resource"
2081
- version = "3.0.8"
2082
  source = "registry+https://github.com/rust-lang/crates.io-index"
2083
- checksum = "63a1d0de4f2249aa0ff5884d7080814f446bb241a559af6c170a41e878ed2d45"
2084
  dependencies = [
2085
  "cc",
2086
  "memchr",
2087
  "rustc_version",
2088
- "toml 0.9.12+spec-1.1.0",
2089
  "vswhom",
2090
  "winreg 0.55.0",
2091
  ]
@@ -4677,6 +4696,7 @@ version = "0.49.0"
4677
  dependencies = [
4678
  "anyhow",
4679
  "candle-core",
 
4680
  "candle-nn",
4681
  "candle-transformers",
4682
  "clap",
@@ -4866,9 +4886,9 @@ dependencies = [
4866
 
4867
  [[package]]
4868
  name = "libc"
4869
- version = "0.2.185"
4870
  source = "registry+https://github.com/rust-lang/crates.io-index"
4871
- checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
4872
 
4873
  [[package]]
4874
  name = "libfuzzer-sys"
@@ -6080,9 +6100,9 @@ checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec"
6080
 
6081
  [[package]]
6082
  name = "pastey"
6083
- version = "0.2.1"
6084
  source = "registry+https://github.com/rust-lang/crates.io-index"
6085
- checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec"
6086
 
6087
  [[package]]
6088
  name = "pathdiff"
@@ -7354,7 +7374,7 @@ dependencies = [
7354
  "http 1.4.0",
7355
  "http-body",
7356
  "http-body-util",
7357
- "pastey 0.2.1",
7358
  "pin-project-lite",
7359
  "rand 0.10.1",
7360
  "rmcp-macros",
@@ -7462,9 +7482,9 @@ dependencies = [
7462
 
7463
  [[package]]
7464
  name = "rustls"
7465
- version = "0.23.38"
7466
  source = "registry+https://github.com/rust-lang/crates.io-index"
7467
- checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21"
7468
  dependencies = [
7469
  "aws-lc-rs",
7470
  "log",
@@ -7490,9 +7510,9 @@ dependencies = [
7490
 
7491
  [[package]]
7492
  name = "rustls-pki-types"
7493
- version = "1.14.0"
7494
  source = "registry+https://github.com/rust-lang/crates.io-index"
7495
- checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
7496
  dependencies = [
7497
  "web-time",
7498
  "zeroize",
 
828
 
829
  [[package]]
830
  name = "blake3"
831
+ version = "1.8.5"
832
  source = "registry+https://github.com/rust-lang/crates.io-index"
833
+ checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce"
834
  dependencies = [
835
  "arrayref",
836
  "arrayvec",
 
1007
  [[package]]
1008
  name = "candle-core"
1009
  version = "0.9.2"
1010
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1011
  dependencies = [
1012
  "byteorder",
1013
  "candle-kernels",
 
1033
  "zip 7.2.0",
1034
  ]
1035
 
1036
+ [[package]]
1037
+ name = "candle-flash-attn"
1038
+ version = "0.9.2"
1039
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1040
+ dependencies = [
1041
+ "anyhow",
1042
+ "candle-core",
1043
+ "candle-flash-attn-build",
1044
+ "half",
1045
+ ]
1046
+
1047
+ [[package]]
1048
+ name = "candle-flash-attn-build"
1049
+ version = "0.9.2"
1050
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1051
+ dependencies = [
1052
+ "anyhow",
1053
+ ]
1054
+
1055
  [[package]]
1056
  name = "candle-kernels"
1057
  version = "0.9.2"
1058
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1059
  dependencies = [
1060
  "bindgen_cuda",
1061
  ]
 
1063
  [[package]]
1064
  name = "candle-metal-kernels"
1065
  version = "0.9.2"
1066
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1067
  dependencies = [
1068
  "half",
1069
  "objc2",
 
1077
  [[package]]
1078
  name = "candle-nn"
1079
  version = "0.9.2"
1080
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1081
  dependencies = [
1082
  "candle-core",
1083
  "candle-metal-kernels",
 
1094
  [[package]]
1095
  name = "candle-transformers"
1096
  version = "0.9.2"
1097
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1098
  dependencies = [
1099
  "byteorder",
1100
  "candle-core",
 
1112
  [[package]]
1113
  name = "candle-ug"
1114
  version = "0.9.2"
1115
+ source = "git+https://github.com/mayocream/candle?branch=flash-attn#e7e71e18414db8de91113963beaabb6b4046a0a5"
1116
  dependencies = [
1117
  "ug",
1118
  "ug-cuda",
 
1169
 
1170
  [[package]]
1171
  name = "cc"
1172
+ version = "1.2.61"
1173
  source = "registry+https://github.com/rust-lang/crates.io-index"
1174
+ checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d"
1175
  dependencies = [
1176
  "find-msvc-tools",
1177
  "jobserver",
 
1809
 
1810
  [[package]]
1811
  name = "data-encoding"
1812
+ version = "2.11.0"
1813
  source = "registry+https://github.com/rust-lang/crates.io-index"
1814
+ checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8"
1815
 
1816
  [[package]]
1817
  name = "debugid"
 
2097
 
2098
  [[package]]
2099
  name = "embed-resource"
2100
+ version = "3.0.9"
2101
  source = "registry+https://github.com/rust-lang/crates.io-index"
2102
+ checksum = "c31a88c8d26de40ed18fe748c547845aa39de1db3afd958f8cb91579f3644bcb"
2103
  dependencies = [
2104
  "cc",
2105
  "memchr",
2106
  "rustc_version",
2107
+ "toml 1.1.2+spec-1.1.0",
2108
  "vswhom",
2109
  "winreg 0.55.0",
2110
  ]
 
4696
  dependencies = [
4697
  "anyhow",
4698
  "candle-core",
4699
+ "candle-flash-attn",
4700
  "candle-nn",
4701
  "candle-transformers",
4702
  "clap",
 
4886
 
4887
  [[package]]
4888
  name = "libc"
4889
+ version = "0.2.186"
4890
  source = "registry+https://github.com/rust-lang/crates.io-index"
4891
+ checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
4892
 
4893
  [[package]]
4894
  name = "libfuzzer-sys"
 
6100
 
6101
  [[package]]
6102
  name = "pastey"
6103
+ version = "0.2.2"
6104
  source = "registry+https://github.com/rust-lang/crates.io-index"
6105
+ checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a"
6106
 
6107
  [[package]]
6108
  name = "pathdiff"
 
7374
  "http 1.4.0",
7375
  "http-body",
7376
  "http-body-util",
7377
+ "pastey 0.2.2",
7378
  "pin-project-lite",
7379
  "rand 0.10.1",
7380
  "rmcp-macros",
 
7482
 
7483
  [[package]]
7484
  name = "rustls"
7485
+ version = "0.23.39"
7486
  source = "registry+https://github.com/rust-lang/crates.io-index"
7487
+ checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
7488
  dependencies = [
7489
  "aws-lc-rs",
7490
  "log",
 
7510
 
7511
  [[package]]
7512
  name = "rustls-pki-types"
7513
+ version = "1.14.1"
7514
  source = "registry+https://github.com/rust-lang/crates.io-index"
7515
+ checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9"
7516
  dependencies = [
7517
  "web-time",
7518
  "zeroize",
Cargo.toml CHANGED
@@ -44,6 +44,7 @@ koharu-rpc = { path = "koharu-rpc", default-features = false }
44
  candle-transformers = "=0.9.2"
45
  candle-core = "=0.9.2"
46
  candle-nn = "=0.9.2"
 
47
  hf-hub = "0.5"
48
  image = "0.25"
49
  anyhow = "1.0"
@@ -102,7 +103,9 @@ cudarc = { version = "0.19.4", features = [
102
  "cublas",
103
  "cublaslt",
104
  "curand",
 
105
  "driver",
 
106
  "nvrtc",
107
  "f16",
108
  "f8",
@@ -166,9 +169,10 @@ natord = "1.0.9"
166
  sentry = { version = "0.47", features = ["tracing"] }
167
 
168
  [patch.crates-io]
169
- candle-transformers = { git = "https://github.com/mayocream/candle", branch = "cuda-dynamic-loading" }
170
- candle-core = { git = "https://github.com/mayocream/candle", branch = "cuda-dynamic-loading" }
171
- candle-nn = { git = "https://github.com/mayocream/candle", branch = "cuda-dynamic-loading" }
 
172
  ug = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
173
  ug-cuda = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
174
 
 
44
  candle-transformers = "=0.9.2"
45
  candle-core = "=0.9.2"
46
  candle-nn = "=0.9.2"
47
+ candle-flash-attn = "=0.9.2"
48
  hf-hub = "0.5"
49
  image = "0.25"
50
  anyhow = "1.0"
 
103
  "cublas",
104
  "cublaslt",
105
  "curand",
106
+ "cudnn",
107
  "driver",
108
+ "dynamic-loading",
109
  "nvrtc",
110
  "f16",
111
  "f8",
 
169
  sentry = { version = "0.47", features = ["tracing"] }
170
 
171
  [patch.crates-io]
172
+ candle-transformers = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
173
+ candle-core = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
174
+ candle-nn = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
175
+ candle-flash-attn = { git = "https://github.com/mayocream/candle", branch = "flash-attn" }
176
  ug = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
177
  ug-cuda = { git = "https://github.com/mayocream/ug", branch = "cuda-dynamic-loading" }
178
 
koharu-ml/Cargo.toml CHANGED
@@ -20,6 +20,7 @@ imageproc = { workspace = true }
20
  candle-core = { workspace = true }
21
  candle-transformers = { workspace = true }
22
  candle-nn = { workspace = true }
 
23
  tokenizers = { workspace = true }
24
  serde = { workspace = true }
25
  serde_json = { workspace = true }
@@ -44,9 +45,14 @@ objc2-foundation = { workspace = true, optional = true }
44
  [features]
45
  cuda = [
46
  "candle-core/cuda",
 
47
  "candle-nn/cuda",
 
48
  "candle-transformers/cuda",
 
49
  "cudarc",
 
 
50
  ]
51
  metal = [
52
  "candle-core/metal",
 
20
  candle-core = { workspace = true }
21
  candle-transformers = { workspace = true }
22
  candle-nn = { workspace = true }
23
+ candle-flash-attn = { workspace = true, optional = true }
24
  tokenizers = { workspace = true }
25
  serde = { workspace = true }
26
  serde_json = { workspace = true }
 
45
  [features]
46
  cuda = [
47
  "candle-core/cuda",
48
+ "candle-core/cudnn",
49
  "candle-nn/cuda",
50
+ "candle-nn/cudnn",
51
  "candle-transformers/cuda",
52
+ "candle-transformers/cudnn",
53
  "cudarc",
54
+ "candle-flash-attn",
55
+ "candle-flash-attn/cudnn",
56
  ]
57
  metal = [
58
  "candle-core/metal",
koharu-ml/src/flux2_klein/mod.rs CHANGED
@@ -191,6 +191,7 @@ impl Flux2Klein {
191
  return Ok(image.clone());
192
  }
193
 
 
194
  let (latents, packed_h, packed_w, size) = {
195
  let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
196
  let image_latents = self.encode_image_latents(&rgb)?;
@@ -226,6 +227,7 @@ impl Flux2Klein {
226
  )?;
227
  }
228
  let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
 
229
 
230
  let mut scheduler =
231
  FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
@@ -236,6 +238,9 @@ impl Flux2Klein {
236
  let initial_timestep = timesteps[start_index];
237
  let mut latents =
238
  pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
 
 
 
239
 
240
  for step_idx in start_index..timesteps.len() {
241
  let timestep = Tensor::from_vec(
@@ -250,7 +255,6 @@ impl Flux2Klein {
250
  ],
251
  1,
252
  )?;
253
- let img_ids = Tensor::cat(&[latent_ids.clone(), condition_ids.clone()], 1)?;
254
  let noise_pred = self.transformer.forward(
255
  &latent_model_input,
256
  &img_ids,
@@ -258,10 +262,15 @@ impl Flux2Klein {
258
  &text_ids,
259
  &timestep,
260
  )?;
 
 
261
  let noise_pred = noise_pred
262
  .narrow(1, 0, latents.dim(1)?)?
263
  .to_dtype(DType::F32)?;
264
- latents = scheduler.step(&noise_pred, &latents)?;
 
 
 
265
  }
266
 
267
  (latents, packed_h, packed_w, size)
@@ -322,6 +331,7 @@ impl Flux2Klein {
322
  reference_image: Option<&DynamicImage>,
323
  options: &Flux2InpaintOptions,
324
  ) -> Result<DynamicImage> {
 
325
  let (latents, packed_h, packed_w, size) = {
326
  let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
327
  let resized_mask = expand_mask(
@@ -362,6 +372,7 @@ impl Flux2Klein {
362
  )?;
363
  }
364
  let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
 
365
 
366
  let mut scheduler =
367
  FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
@@ -373,6 +384,8 @@ impl Flux2Klein {
373
  let initial_timestep = timesteps[start_index];
374
  let mut latents =
375
  pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
 
 
376
 
377
  for step_idx in start_index..timesteps.len() {
378
  let timestep = Tensor::from_vec(
@@ -387,7 +400,6 @@ impl Flux2Klein {
387
  ],
388
  1,
389
  )?;
390
- let img_ids = Tensor::cat(&[latent_ids.clone(), condition_ids.clone()], 1)?;
391
  let noise_pred = self.transformer.forward(
392
  &latent_model_input,
393
  &img_ids,
@@ -395,10 +407,15 @@ impl Flux2Klein {
395
  &text_ids,
396
  &timestep,
397
  )?;
 
 
398
  let noise_pred = noise_pred
399
  .narrow(1, 0, latents.dim(1)?)?
400
  .to_dtype(DType::F32)?;
401
- latents = scheduler.step(&noise_pred, &latents)?;
 
 
 
402
 
403
  let init_latents = if step_idx + 1 < timesteps.len() {
404
  scheduler.scale_noise(
@@ -409,9 +426,11 @@ impl Flux2Klein {
409
  } else {
410
  image_latents_packed.clone()
411
  };
412
- let keep_mask = ((&latent_mask * -1.0)? + 1.0)?;
413
- latents = (keep_mask.broadcast_mul(&init_latents)?
414
  + latent_mask.broadcast_mul(&latents)?)?;
 
 
 
415
  }
416
 
417
  (latents, packed_h, packed_w, size)
@@ -457,10 +476,53 @@ impl Flux2Klein {
457
  }
458
  }
459
 
460
- fn transformer_dtype(_device: &Device) -> DType {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  DType::F32
462
  }
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  fn inpaint_crop_bounds(
465
  image: &DynamicImage,
466
  mask: &DynamicImage,
 
191
  return Ok(image.clone());
192
  }
193
 
194
+ let _cuda_cleanup = CudaTemporaryMemoryCleanup::new(&self.device);
195
  let (latents, packed_h, packed_w, size) = {
196
  let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
197
  let image_latents = self.encode_image_latents(&rgb)?;
 
227
  )?;
228
  }
229
  let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
230
+ let img_ids = Tensor::cat(&[latent_ids, condition_ids], 1)?;
231
 
232
  let mut scheduler =
233
  FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
 
238
  let initial_timestep = timesteps[start_index];
239
  let mut latents =
240
  pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
241
+ drop(image_latents_packed);
242
+ drop(image_latents);
243
+ drop(noise);
244
 
245
  for step_idx in start_index..timesteps.len() {
246
  let timestep = Tensor::from_vec(
 
255
  ],
256
  1,
257
  )?;
 
258
  let noise_pred = self.transformer.forward(
259
  &latent_model_input,
260
  &img_ids,
 
262
  &text_ids,
263
  &timestep,
264
  )?;
265
+ drop(latent_model_input);
266
+ drop(timestep);
267
  let noise_pred = noise_pred
268
  .narrow(1, 0, latents.dim(1)?)?
269
  .to_dtype(DType::F32)?;
270
+ let next_latents = scheduler.step(&noise_pred, &latents)?;
271
+ drop(noise_pred);
272
+ let previous_latents = std::mem::replace(&mut latents, next_latents);
273
+ drop(previous_latents);
274
  }
275
 
276
  (latents, packed_h, packed_w, size)
 
331
  reference_image: Option<&DynamicImage>,
332
  options: &Flux2InpaintOptions,
333
  ) -> Result<DynamicImage> {
334
+ let _cuda_cleanup = CudaTemporaryMemoryCleanup::new(&self.device);
335
  let (latents, packed_h, packed_w, size) = {
336
  let (rgb, size) = prepare_rgb_image(image, options.max_pixels);
337
  let resized_mask = expand_mask(
 
372
  )?;
373
  }
374
  let condition_latents = condition_latents.to_dtype(transformer_dtype)?;
375
+ let img_ids = Tensor::cat(&[latent_ids, condition_ids], 1)?;
376
 
377
  let mut scheduler =
378
  FlowMatchScheduler::new(options.num_inference_steps, packed_h * packed_w);
 
384
  let initial_timestep = timesteps[start_index];
385
  let mut latents =
386
  pack_latents(&scheduler.scale_noise(&image_latents, initial_timestep, &noise)?)?;
387
+ let keep_mask = ((&latent_mask * -1.0)? + 1.0)?;
388
+ drop(noise);
389
 
390
  for step_idx in start_index..timesteps.len() {
391
  let timestep = Tensor::from_vec(
 
400
  ],
401
  1,
402
  )?;
 
403
  let noise_pred = self.transformer.forward(
404
  &latent_model_input,
405
  &img_ids,
 
407
  &text_ids,
408
  &timestep,
409
  )?;
410
+ drop(latent_model_input);
411
+ drop(timestep);
412
  let noise_pred = noise_pred
413
  .narrow(1, 0, latents.dim(1)?)?
414
  .to_dtype(DType::F32)?;
415
+ let next_latents = scheduler.step(&noise_pred, &latents)?;
416
+ drop(noise_pred);
417
+ let previous_latents = std::mem::replace(&mut latents, next_latents);
418
+ drop(previous_latents);
419
 
420
  let init_latents = if step_idx + 1 < timesteps.len() {
421
  scheduler.scale_noise(
 
426
  } else {
427
  image_latents_packed.clone()
428
  };
429
+ let masked_latents = (keep_mask.broadcast_mul(&init_latents)?
 
430
  + latent_mask.broadcast_mul(&latents)?)?;
431
+ drop(init_latents);
432
+ let previous_latents = std::mem::replace(&mut latents, masked_latents);
433
+ drop(previous_latents);
434
  }
435
 
436
  (latents, packed_h, packed_w, size)
 
476
  }
477
  }
478
 
479
+ struct CudaTemporaryMemoryCleanup<'a> {
480
+ device: &'a Device,
481
+ }
482
+
483
+ impl<'a> CudaTemporaryMemoryCleanup<'a> {
484
+ fn new(device: &'a Device) -> Self {
485
+ Self { device }
486
+ }
487
+ }
488
+
489
+ impl Drop for CudaTemporaryMemoryCleanup<'_> {
490
+ fn drop(&mut self) {
491
+ let _ = release_cuda_temporary_memory(self.device);
492
+ }
493
+ }
494
+
495
+ fn transformer_dtype(device: &Device) -> DType {
496
+ if device.is_cuda() {
497
+ return DType::BF16;
498
+ }
499
+
500
  DType::F32
501
  }
502
 
503
+ fn release_cuda_temporary_memory(device: &Device) -> Result<()> {
504
+ device.synchronize()?;
505
+
506
+ #[cfg(feature = "cuda")]
507
+ if let Ok(cuda_device) = device.as_cuda_device() {
508
+ let stream = cuda_device.cuda_stream();
509
+ let context = stream.context();
510
+ if context.has_async_alloc() {
511
+ context.bind_to_thread()?;
512
+ let pool = unsafe {
513
+ candle_core::cuda::cudarc::driver::result::device::get_mem_pool(
514
+ context.cu_device(),
515
+ )?
516
+ };
517
+ unsafe {
518
+ candle_core::cuda::cudarc::driver::result::mem_pool::trim_to(pool, 0)?;
519
+ }
520
+ }
521
+ }
522
+
523
+ Ok(())
524
+ }
525
+
526
  fn inpaint_crop_bounds(
527
  image: &DynamicImage,
528
  mask: &DynamicImage,
koharu-ml/src/flux2_klein/transformer.rs CHANGED
@@ -1,8 +1,6 @@
1
  use std::path::Path;
2
 
3
- use candle_core::{D, DType, IndexOp, Module, Result, Tensor};
4
- use candle_nn::{LayerNorm, RmsNorm};
5
- use candle_transformers::quantized_nn::{Linear, linear_b};
6
  use candle_transformers::quantized_var_builder::VarBuilder;
7
 
8
  #[derive(Debug, Clone)]
@@ -32,8 +30,97 @@ impl Default for Flux2TransformerConfig {
32
  }
33
  }
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  fn qlinear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
36
- linear_b(in_dim, out_dim, false, vb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  }
38
 
39
  fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
@@ -83,6 +170,18 @@ fn apply_rope(xs: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
83
  fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
84
  let dim = q.dim(D::Minus1)?;
85
  let scale = 1.0 / (dim as f64).sqrt();
 
 
 
 
 
 
 
 
 
 
 
 
86
  if q.device().is_metal() {
87
  return candle_nn::ops::sdpa(q, k, v, None, false, scale as f32, 1.0);
88
  }
@@ -107,6 +206,8 @@ fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor>
107
  let q = apply_rope(q, pe)?.contiguous()?;
108
  let k = apply_rope(k, pe)?.contiguous()?;
109
  let xs = scaled_dot_product_attention(&q, &k, v)?;
 
 
110
  xs.transpose(1, 2)?.flatten_from(2)
111
  }
112
 
@@ -265,6 +366,7 @@ impl SelfAttention {
265
  let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
266
  let q = q.apply(&self.norm.query_norm)?;
267
  let k = k.apply(&self.norm.key_norm)?;
 
268
  Ok((q, k, v))
269
  }
270
  }
@@ -284,7 +386,9 @@ impl Mlp {
284
  }
285
 
286
  fn forward(&self, xs: &Tensor) -> Result<Tensor> {
287
- swiglu(&xs.apply(&self.lin1)?)?.apply(&self.lin2)
 
 
288
  }
289
  }
290
 
@@ -336,8 +440,10 @@ impl DoubleStreamBlock {
336
 
337
  let img_modulated = img_mod1.scale_shift(&img.apply(&self.img_norm1)?)?;
338
  let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
 
339
  let txt_modulated = txt_mod1.scale_shift(&txt.apply(&self.txt_norm1)?)?;
340
  let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
 
341
 
342
  let attn = {
343
  let q = Tensor::cat(&[&txt_q, &img_q], 2)?;
@@ -361,44 +467,31 @@ impl DoubleStreamBlock {
361
  let img_attn = img_attn.apply(&self.img_attn.proj)?;
362
  let txt_attn = txt_attn.apply(&self.txt_attn.proj)?;
363
  drop(attn);
 
 
 
 
 
 
 
364
  drop(img_modulated);
 
 
 
 
 
 
 
 
 
365
  drop(txt_modulated);
366
-
367
- let img = (img + img_mod1.gate(&img_attn)?)?;
368
- drop(img_attn);
369
- let img_mlp = img_mod2
370
- .scale_shift(&img.apply(&self.img_norm2)?)?
371
- .apply_fn(|xs| self.img_mlp.forward(xs))?;
372
- let img = (img + img_mod2.gate(&img_mlp)?)?;
373
- drop(img_mlp);
374
-
375
- let txt = (txt + txt_mod1.gate(&txt_attn)?)?;
376
- drop(txt_attn);
377
- let txt_mlp = txt_mod2
378
- .scale_shift(&txt.apply(&self.txt_norm2)?)?
379
- .apply_fn(|xs| self.txt_mlp.forward(xs))?;
380
- let txt = (txt + txt_mod2.gate(&txt_mlp)?)?;
381
- drop(txt_mlp);
382
 
383
  Ok((img, txt))
384
  }
385
  }
386
 
387
- trait ApplyFn {
388
- fn apply_fn<F>(&self, f: F) -> Result<Tensor>
389
- where
390
- F: FnOnce(&Tensor) -> Result<Tensor>;
391
- }
392
-
393
- impl ApplyFn for Tensor {
394
- fn apply_fn<F>(&self, f: F) -> Result<Tensor>
395
- where
396
- F: FnOnce(&Tensor) -> Result<Tensor>,
397
- {
398
- f(self)
399
- }
400
- }
401
-
402
  #[derive(Debug, Clone)]
403
  struct SingleStreamBlock {
404
  linear1: Linear,
@@ -432,8 +525,11 @@ impl SingleStreamBlock {
432
 
433
  fn forward(&self, xs: &Tensor, mods: &[ModulationOut], pe: &Tensor) -> Result<Tensor> {
434
  let mod_ = &mods[0];
435
- let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
 
 
436
  let qkv_mlp = x_mod.apply(&self.linear1)?;
 
437
  let qkv = qkv_mlp.narrow(D::Minus1, 0, 3 * self.hidden_size)?;
438
  let (b, len, _) = qkv.dims3()?;
439
  let qkv = qkv.reshape((b, len, 3, self.num_heads, ()))?;
@@ -441,6 +537,8 @@ impl SingleStreamBlock {
441
  let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
442
  let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
443
  let mlp = qkv_mlp.narrow(D::Minus1, 3 * self.hidden_size, self.mlp_size * 2)?;
 
 
444
  let q = q.apply(&self.norm.query_norm)?;
445
  let k = k.apply(&self.norm.key_norm)?;
446
  let attn = attention(&q, &k, &v, pe)?;
@@ -448,10 +546,13 @@ impl SingleStreamBlock {
448
  drop(k);
449
  drop(v);
450
  let mlp = swiglu(&mlp)?;
451
- let output = Tensor::cat(&[&attn, &mlp], D::Minus1)?.apply(&self.linear2)?;
452
  drop(attn);
453
  drop(mlp);
454
- xs + mod_.gate(&output)?
 
 
 
455
  }
456
  }
457
 
@@ -585,6 +686,7 @@ impl Flux2Transformer {
585
  let dtype = img.dtype();
586
  let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
587
  let pe = self.pe_embedder.forward(&ids)?;
 
588
  let mut img = img.apply(&self.img_in)?;
589
  let mut txt = txt.apply(&self.txt_in)?;
590
  let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
@@ -595,14 +697,22 @@ impl Flux2Transformer {
595
  for block in &self.double_blocks {
596
  (img, txt) = block.forward(&img, &txt, &ds_img_mods, &ds_txt_mods, &pe)?;
597
  }
 
 
598
  let txt_len = txt.dim(1)?;
599
  let img_len = img.dim(1)?;
600
  let mut xs = Tensor::cat(&[&txt, &img], 1)?;
 
 
601
  for block in &self.single_blocks {
602
  xs = block.forward(&xs, &ss_mods, &pe)?;
603
  }
 
 
604
  let img = xs.narrow(1, txt_len, img_len)?;
605
- self.final_layer.forward(&img, &vec_)
 
 
606
  }
607
 
608
  pub fn in_channels(&self) -> usize {
 
1
  use std::path::Path;
2
 
3
+ use candle_core::{D, DType, IndexOp, Module, Result, Tensor, quantized::QMatMul};
 
 
4
  use candle_transformers::quantized_var_builder::VarBuilder;
5
 
6
  #[derive(Debug, Clone)]
 
30
  }
31
  }
32
 
33
+ #[derive(Debug, Clone)]
34
+ struct Linear {
35
+ weight: QMatMul,
36
+ }
37
+
38
+ impl Module for Linear {
39
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
40
+ let dtype = xs.dtype();
41
+ let xs = if should_promote_for_cuda(xs) {
42
+ xs.to_dtype(DType::F32)?
43
+ } else {
44
+ xs.clone()
45
+ };
46
+ let ys = xs.apply(&self.weight)?;
47
+ if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
48
+ ys.to_dtype(dtype)
49
+ } else {
50
+ Ok(ys)
51
+ }
52
+ }
53
+ }
54
+
55
  fn qlinear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
56
+ let weight = vb.get((out_dim, in_dim), "weight")?;
57
+ Ok(Linear {
58
+ weight: QMatMul::from_arc(weight)?,
59
+ })
60
+ }
61
+
62
+ #[derive(Debug, Clone)]
63
+ struct LayerNorm {
64
+ inner: candle_nn::LayerNorm,
65
+ }
66
+
67
+ impl LayerNorm {
68
+ fn new_no_bias(weight: Tensor, eps: f64) -> Self {
69
+ Self {
70
+ inner: candle_nn::LayerNorm::new_no_bias(weight, eps),
71
+ }
72
+ }
73
+ }
74
+
75
+ impl Module for LayerNorm {
76
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
77
+ let dtype = xs.dtype();
78
+ let xs = if should_promote_for_cuda(xs) {
79
+ xs.to_dtype(DType::F32)?
80
+ } else {
81
+ xs.clone()
82
+ };
83
+ let ys = xs.apply(&self.inner)?;
84
+ if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
85
+ ys.to_dtype(dtype)
86
+ } else {
87
+ Ok(ys)
88
+ }
89
+ }
90
+ }
91
+
92
+ #[derive(Debug, Clone)]
93
+ struct RmsNorm {
94
+ inner: candle_nn::RmsNorm,
95
+ }
96
+
97
+ impl RmsNorm {
98
+ fn new(weight: Tensor, eps: f64) -> Self {
99
+ Self {
100
+ inner: candle_nn::RmsNorm::new(weight, eps),
101
+ }
102
+ }
103
+ }
104
+
105
+ impl Module for RmsNorm {
106
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
107
+ let dtype = xs.dtype();
108
+ let xs = if should_promote_for_cuda(xs) {
109
+ xs.to_dtype(DType::F32)?
110
+ } else {
111
+ xs.clone()
112
+ };
113
+ let ys = xs.apply(&self.inner)?;
114
+ if ys.dtype() != dtype && matches!(dtype, DType::BF16 | DType::F16) {
115
+ ys.to_dtype(dtype)
116
+ } else {
117
+ Ok(ys)
118
+ }
119
+ }
120
+ }
121
+
122
+ fn should_promote_for_cuda(xs: &Tensor) -> bool {
123
+ xs.device().is_cuda() && matches!(xs.dtype(), DType::BF16 | DType::F16)
124
  }
125
 
126
  fn layer_norm(dim: usize, vb: &VarBuilder) -> Result<LayerNorm> {
 
170
  fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
171
  let dim = q.dim(D::Minus1)?;
172
  let scale = 1.0 / (dim as f64).sqrt();
173
+ #[cfg(feature = "cuda")]
174
+ if q.device().is_cuda() {
175
+ let q = q.transpose(1, 2)?.contiguous()?;
176
+ let k = k.transpose(1, 2)?.contiguous()?;
177
+ let v = v.transpose(1, 2)?.contiguous()?;
178
+ let xs = candle_flash_attn::flash_attn(&q, &k, &v, scale as f32, false)?;
179
+ drop(q);
180
+ drop(k);
181
+ drop(v);
182
+ return xs.transpose(1, 2);
183
+ }
184
+
185
  if q.device().is_metal() {
186
  return candle_nn::ops::sdpa(q, k, v, None, false, scale as f32, 1.0);
187
  }
 
206
  let q = apply_rope(q, pe)?.contiguous()?;
207
  let k = apply_rope(k, pe)?.contiguous()?;
208
  let xs = scaled_dot_product_attention(&q, &k, v)?;
209
+ drop(q);
210
+ drop(k);
211
  xs.transpose(1, 2)?.flatten_from(2)
212
  }
213
 
 
366
  let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
367
  let q = q.apply(&self.norm.query_norm)?;
368
  let k = k.apply(&self.norm.key_norm)?;
369
+ drop(qkv);
370
  Ok((q, k, v))
371
  }
372
  }
 
386
  }
387
 
388
  fn forward(&self, xs: &Tensor) -> Result<Tensor> {
389
+ let xs = xs.apply(&self.lin1)?;
390
+ let xs = swiglu(&xs)?;
391
+ xs.apply(&self.lin2)
392
  }
393
  }
394
 
 
440
 
441
  let img_modulated = img_mod1.scale_shift(&img.apply(&self.img_norm1)?)?;
442
  let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
443
+ drop(img_modulated);
444
  let txt_modulated = txt_mod1.scale_shift(&txt.apply(&self.txt_norm1)?)?;
445
  let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
446
+ drop(txt_modulated);
447
 
448
  let attn = {
449
  let q = Tensor::cat(&[&txt_q, &img_q], 2)?;
 
467
  let img_attn = img_attn.apply(&self.img_attn.proj)?;
468
  let txt_attn = txt_attn.apply(&self.txt_attn.proj)?;
469
  drop(attn);
470
+
471
+ let img_attn = img_mod1.gate(&img_attn)?;
472
+ let img = (img + img_attn)?;
473
+ let img_normed = img.apply(&self.img_norm2)?;
474
+ let img_modulated = img_mod2.scale_shift(&img_normed)?;
475
+ drop(img_normed);
476
+ let img_mlp = self.img_mlp.forward(&img_modulated)?;
477
  drop(img_modulated);
478
+ let img_mlp = img_mod2.gate(&img_mlp)?;
479
+ let img = (img + img_mlp)?;
480
+
481
+ let txt_attn = txt_mod1.gate(&txt_attn)?;
482
+ let txt = (txt + txt_attn)?;
483
+ let txt_normed = txt.apply(&self.txt_norm2)?;
484
+ let txt_modulated = txt_mod2.scale_shift(&txt_normed)?;
485
+ drop(txt_normed);
486
+ let txt_mlp = self.txt_mlp.forward(&txt_modulated)?;
487
  drop(txt_modulated);
488
+ let txt_mlp = txt_mod2.gate(&txt_mlp)?;
489
+ let txt = (txt + txt_mlp)?;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
  Ok((img, txt))
492
  }
493
  }
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  #[derive(Debug, Clone)]
496
  struct SingleStreamBlock {
497
  linear1: Linear,
 
525
 
526
  fn forward(&self, xs: &Tensor, mods: &[ModulationOut], pe: &Tensor) -> Result<Tensor> {
527
  let mod_ = &mods[0];
528
+ let x_normed = xs.apply(&self.pre_norm)?;
529
+ let x_mod = mod_.scale_shift(&x_normed)?;
530
+ drop(x_normed);
531
  let qkv_mlp = x_mod.apply(&self.linear1)?;
532
+ drop(x_mod);
533
  let qkv = qkv_mlp.narrow(D::Minus1, 0, 3 * self.hidden_size)?;
534
  let (b, len, _) = qkv.dims3()?;
535
  let qkv = qkv.reshape((b, len, 3, self.num_heads, ()))?;
 
537
  let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
538
  let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
539
  let mlp = qkv_mlp.narrow(D::Minus1, 3 * self.hidden_size, self.mlp_size * 2)?;
540
+ drop(qkv_mlp);
541
+ drop(qkv);
542
  let q = q.apply(&self.norm.query_norm)?;
543
  let k = k.apply(&self.norm.key_norm)?;
544
  let attn = attention(&q, &k, &v, pe)?;
 
546
  drop(k);
547
  drop(v);
548
  let mlp = swiglu(&mlp)?;
549
+ let output = Tensor::cat(&[&attn, &mlp], D::Minus1)?;
550
  drop(attn);
551
  drop(mlp);
552
+ let output = output.apply(&self.linear2)?;
553
+ let gated = mod_.gate(&output)?;
554
+ drop(output);
555
+ xs + gated
556
  }
557
  }
558
 
 
686
  let dtype = img.dtype();
687
  let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
688
  let pe = self.pe_embedder.forward(&ids)?;
689
+ drop(ids);
690
  let mut img = img.apply(&self.img_in)?;
691
  let mut txt = txt.apply(&self.txt_in)?;
692
  let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
 
697
  for block in &self.double_blocks {
698
  (img, txt) = block.forward(&img, &txt, &ds_img_mods, &ds_txt_mods, &pe)?;
699
  }
700
+ drop(ds_img_mods);
701
+ drop(ds_txt_mods);
702
  let txt_len = txt.dim(1)?;
703
  let img_len = img.dim(1)?;
704
  let mut xs = Tensor::cat(&[&txt, &img], 1)?;
705
+ drop(txt);
706
+ drop(img);
707
  for block in &self.single_blocks {
708
  xs = block.forward(&xs, &ss_mods, &pe)?;
709
  }
710
+ drop(ss_mods);
711
+ drop(pe);
712
  let img = xs.narrow(1, txt_len, img_len)?;
713
+ let xs = self.final_layer.forward(&img, &vec_)?;
714
+ drop(img);
715
+ Ok(xs)
716
  }
717
 
718
  pub fn in_channels(&self) -> usize {
koharu-ml/src/flux2_klein/vae.rs CHANGED
@@ -1,4 +1,4 @@
1
- use candle_core::{D, Module, Result, Tensor};
2
  use candle_nn::{Conv2d, Conv2dConfig, GroupNorm, VarBuilder, conv2d, group_norm};
3
 
4
  use super::latents::{patchify_latents, unpatchify_latents};
@@ -30,6 +30,15 @@ impl Default for Flux2VaeConfig {
30
  }
31
  }
32
 
 
 
 
 
 
 
 
 
 
33
  fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
34
  let dim = q.dim(D::Minus1)?;
35
  let scale = 1.0 / (dim as f64).sqrt();
@@ -113,10 +122,7 @@ impl ResnetBlock2D {
113
  num_groups: usize,
114
  vb: VarBuilder,
115
  ) -> Result<Self> {
116
- let conv_cfg = Conv2dConfig {
117
- padding: 1,
118
- ..Default::default()
119
- };
120
  let norm1 = group_norm(num_groups, in_channels, 1e-6, vb.pp("norm1"))?;
121
  let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vb.pp("conv1"))?;
122
  let norm2 = group_norm(num_groups, out_channels, 1e-6, vb.pp("norm2"))?;
@@ -126,7 +132,7 @@ impl ResnetBlock2D {
126
  in_channels,
127
  out_channels,
128
  1,
129
- Default::default(),
130
  vb.pp("conv_shortcut"),
131
  )?)
132
  } else {
@@ -165,11 +171,7 @@ struct Downsample2D {
165
 
166
  impl Downsample2D {
167
  fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
168
- let conv_cfg = Conv2dConfig {
169
- stride: 2,
170
- padding: 0,
171
- ..Default::default()
172
- };
173
  let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
174
  Ok(Self { conv })
175
  }
@@ -243,10 +245,7 @@ struct Upsample2D {
243
 
244
  impl Upsample2D {
245
  fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
246
- let conv_cfg = Conv2dConfig {
247
- padding: 1,
248
- ..Default::default()
249
- };
250
  let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
251
  Ok(Self { conv })
252
  }
@@ -342,10 +341,7 @@ struct Encoder {
342
 
343
  impl Encoder {
344
  fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
345
- let conv_cfg = Conv2dConfig {
346
- padding: 1,
347
- ..Default::default()
348
- };
349
  let conv_in = conv2d(
350
  cfg.in_channels,
351
  cfg.block_out_channels[0],
@@ -419,10 +415,7 @@ struct Decoder {
419
 
420
  impl Decoder {
421
  fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
422
- let conv_cfg = Conv2dConfig {
423
- padding: 1,
424
- ..Default::default()
425
- };
426
  let mid_channels = *cfg.decoder_block_out_channels.last().unwrap();
427
  let conv_in = conv2d(
428
  cfg.latent_channels,
@@ -512,14 +505,14 @@ impl Flux2Vae {
512
  2 * cfg.latent_channels,
513
  2 * cfg.latent_channels,
514
  1,
515
- Default::default(),
516
  vb.pp("quant_conv"),
517
  )?;
518
  let post_quant_conv = conv2d(
519
  cfg.latent_channels,
520
  cfg.latent_channels,
521
  1,
522
- Default::default(),
523
  vb.pp("post_quant_conv"),
524
  )?;
525
  let bn_running_mean = vb.get(4 * cfg.latent_channels, "bn.running_mean")?;
 
1
+ use candle_core::{D, Module, Result, Tensor, conv::CudnnFwdAlgo};
2
  use candle_nn::{Conv2d, Conv2dConfig, GroupNorm, VarBuilder, conv2d, group_norm};
3
 
4
  use super::latents::{patchify_latents, unpatchify_latents};
 
30
  }
31
  }
32
 
33
+ fn vae_conv_config(padding: usize, stride: usize) -> Conv2dConfig {
34
+ Conv2dConfig {
35
+ padding,
36
+ stride,
37
+ cudnn_fwd_algo: Some(CudnnFwdAlgo::ImplicitGemm),
38
+ ..Default::default()
39
+ }
40
+ }
41
+
42
  fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
43
  let dim = q.dim(D::Minus1)?;
44
  let scale = 1.0 / (dim as f64).sqrt();
 
122
  num_groups: usize,
123
  vb: VarBuilder,
124
  ) -> Result<Self> {
125
+ let conv_cfg = vae_conv_config(1, 1);
 
 
 
126
  let norm1 = group_norm(num_groups, in_channels, 1e-6, vb.pp("norm1"))?;
127
  let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vb.pp("conv1"))?;
128
  let norm2 = group_norm(num_groups, out_channels, 1e-6, vb.pp("norm2"))?;
 
132
  in_channels,
133
  out_channels,
134
  1,
135
+ vae_conv_config(0, 1),
136
  vb.pp("conv_shortcut"),
137
  )?)
138
  } else {
 
171
 
172
  impl Downsample2D {
173
  fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
174
+ let conv_cfg = vae_conv_config(0, 2);
 
 
 
 
175
  let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
176
  Ok(Self { conv })
177
  }
 
245
 
246
  impl Upsample2D {
247
  fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
248
+ let conv_cfg = vae_conv_config(1, 1);
 
 
 
249
  let conv = conv2d(channels, channels, 3, conv_cfg, vb.pp("conv"))?;
250
  Ok(Self { conv })
251
  }
 
341
 
342
  impl Encoder {
343
  fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
344
+ let conv_cfg = vae_conv_config(1, 1);
 
 
 
345
  let conv_in = conv2d(
346
  cfg.in_channels,
347
  cfg.block_out_channels[0],
 
415
 
416
  impl Decoder {
417
  fn new(cfg: &Flux2VaeConfig, vb: VarBuilder) -> Result<Self> {
418
+ let conv_cfg = vae_conv_config(1, 1);
 
 
 
419
  let mid_channels = *cfg.decoder_block_out_channels.last().unwrap();
420
  let conv_in = conv2d(
421
  cfg.latent_channels,
 
505
  2 * cfg.latent_channels,
506
  2 * cfg.latent_channels,
507
  1,
508
+ vae_conv_config(0, 1),
509
  vb.pp("quant_conv"),
510
  )?;
511
  let post_quant_conv = conv2d(
512
  cfg.latent_channels,
513
  cfg.latent_channels,
514
  1,
515
+ vae_conv_config(0, 1),
516
  vb.pp("post_quant_conv"),
517
  )?;
518
  let bn_running_mean = vb.get(4 * cfg.latent_channels, "bn.running_mean")?;
koharu-runtime/src/cuda.rs CHANGED
@@ -11,6 +11,7 @@ use crate::loader::{add_runtime_search_path, preload_library};
11
  const CUDA_SUCCESS: i32 = 0;
12
  const CUDA_13_0_DRIVER_VERSION: i32 = 13000;
13
  const CUDA_13_1_DRIVER_VERSION: i32 = 13010;
 
14
  const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: i32 = 75;
15
  const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: i32 = 76;
16
  const MIN_COMPUTE_CAPABILITY: (i32, i32) = (7, 5); // Turing (RTX 20xx) and above
@@ -64,6 +65,31 @@ const WHEELS: &[WheelSpec] = &[
64
  windows_dylibs: &["curand64_10.dll"],
65
  linux_dylibs: &["libcurand.so.10"],
66
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  ];
68
 
69
  impl CudaDriverVersion {
@@ -379,9 +405,10 @@ impl WheelSpec {
379
  fn source_id() -> Result<String> {
380
  let packages = WHEELS.iter().map(|wheel| wheel.package).collect::<Vec<_>>();
381
  Ok(format!(
382
- "cuda;platform={};wheels={}",
383
  platform_tags()?.join(","),
384
- packages.join(",")
 
385
  ))
386
  }
387
 
@@ -458,6 +485,20 @@ mod tests {
458
  }
459
  }
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  #[test]
462
  fn parses_major_minor_from_driver_version() {
463
  let version = CudaDriverVersion::from_raw(13010);
 
11
  const CUDA_SUCCESS: i32 = 0;
12
  const CUDA_13_0_DRIVER_VERSION: i32 = 13000;
13
  const CUDA_13_1_DRIVER_VERSION: i32 = 13010;
14
+ const CUDA_EXTRACT_REVISION: u32 = 2;
15
  const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: i32 = 75;
16
  const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: i32 = 76;
17
  const MIN_COMPUTE_CAPABILITY: (i32, i32) = (7, 5); // Turing (RTX 20xx) and above
 
65
  windows_dylibs: &["curand64_10.dll"],
66
  linux_dylibs: &["libcurand.so.10"],
67
  },
68
+ WheelSpec {
69
+ package: "nvidia-cudnn-cu13/9.21.0.82",
70
+ windows_dylibs: &[
71
+ "cudnn64_9.dll",
72
+ "cudnn_adv64_9.dll",
73
+ "cudnn_cnn64_9.dll",
74
+ "cudnn_engines_precompiled64_9.dll",
75
+ "cudnn_engines_runtime_compiled64_9.dll",
76
+ "cudnn_engines_tensor_ir64_9.dll",
77
+ "cudnn_graph64_9.dll",
78
+ "cudnn_heuristic64_9.dll",
79
+ "cudnn_ops64_9.dll",
80
+ ],
81
+ linux_dylibs: &[
82
+ "libcudnn.so.9",
83
+ "libcudnn_adv.so.9",
84
+ "libcudnn_cnn.so.9",
85
+ "libcudnn_engines_precompiled.so.9",
86
+ "libcudnn_engines_runtime_compiled.so.9",
87
+ "libcudnn_engines_tensor_ir.so.9",
88
+ "libcudnn_graph.so.9",
89
+ "libcudnn_heuristic.so.9",
90
+ "libcudnn_ops.so.9",
91
+ ],
92
+ },
93
  ];
94
 
95
  impl CudaDriverVersion {
 
405
  fn source_id() -> Result<String> {
406
  let packages = WHEELS.iter().map(|wheel| wheel.package).collect::<Vec<_>>();
407
  Ok(format!(
408
+ "cuda;platform={};wheels={};extract={}",
409
  platform_tags()?.join(","),
410
+ packages.join(","),
411
+ CUDA_EXTRACT_REVISION
412
  ))
413
  }
414
 
 
485
  }
486
  }
487
 
488
+ #[test]
489
+ fn cuda_runtime_includes_cudnn() {
490
+ let wheel = WHEELS
491
+ .iter()
492
+ .find(|wheel| wheel.package.starts_with("nvidia-cudnn-cu13/"))
493
+ .expect("missing cuDNN runtime wheel");
494
+
495
+ #[cfg(target_os = "windows")]
496
+ assert!(wheel.dylibs().contains(&"cudnn64_9.dll"));
497
+
498
+ #[cfg(target_os = "linux")]
499
+ assert!(wheel.dylibs().contains(&"libcudnn.so.9"));
500
+ }
501
+
502
  #[test]
503
  fn parses_major_minor_from_driver_version() {
504
  let version = CudaDriverVersion::from_raw(13010);
koharu/tauri.windows.conf.json CHANGED
@@ -1,5 +1,5 @@
1
- {
2
- "$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
3
  "identifier": "Koharu",
4
  "build": {
5
  "features": [
 
1
+ {
2
+ "$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
3
  "identifier": "Koharu",
4
  "build": {
5
  "features": [