manbeast3b commited on
Commit
58b1cf8
·
verified ·
1 Parent(s): fd77cf5

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +9 -0
src/pipeline.py CHANGED
@@ -17,6 +17,15 @@ import os
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
18
  Pipeline = None
19
 
 
 
 
 
 
 
 
 
 
20
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
21
  def empty_cache():
22
  start = time.time()
 
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
18
  Pipeline = None
19
 
20
+ # Define the quantization config
21
+ nf4_config = BitsAndBytesConfig(
22
+ load_in_4bit=True,
23
+ bnb_4bit_quant_type="nf4",
24
+ bnb_4bit_use_double_quant=True,
25
+ bnb_4bit_compute_dtype=torch.bfloat16
26
+ )
27
+
28
+
29
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
30
  def empty_cache():
31
  start = time.time()