manbeast3b committed on
Commit
38fc09e
·
verified ·
1 Parent(s): 796bc4c

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +38 -2
src/pipeline.py CHANGED
@@ -14,6 +14,12 @@ import time
14
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
15
  # from torchao.quantization import quantize_,int8_weight_only
16
  import os
 
 
 
 
 
 
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
18
  Pipeline = None
19
 
@@ -44,8 +50,38 @@ def load_pipeline() -> Pipeline:
44
  torch.backends.cuda.matmul.allow_tf32 = True
45
  torch.cuda.set_per_process_memory_fraction(0.99)
46
  pipeline.text_encoder.to(memory_format=torch.channels_last)
47
- pipeline.transformer.to(memory_format=torch.channels_last)
48
- quantize_dynamic(pipeline.transformer, dtype=torch.float8_e5m2fnuz, inplace=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  pipeline.vae.to(memory_format=torch.channels_last)
51
  pipeline.vae = torch.compile(pipeline.vae)
 
14
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
15
  # from torchao.quantization import quantize_,int8_weight_only
16
  import os
17
+
18
+ from torch.ao.quantization import prepare, convert
19
+ from torch.ao.quantization import QConfig
20
+ from torch.ao.quantization.observer import MinMaxObserver
21
+ from torch.ao.quantization.quantize import quantize_dynamic
22
+
23
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
24
  Pipeline = None
25
 
 
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.cuda.set_per_process_memory_fraction(0.99)
52
  pipeline.text_encoder.to(memory_format=torch.channels_last)
53
+ # pipeline.transformer.to(memory_format=torch.channels_last)
54
+ # quantize_dynamic(pipeline.transformer, dtype=torch.float8_e5m2fnuz, inplace=True)
55
+
56
+
57
+ # Define a custom qconfig for float8_e5m2fnuz
58
+ float8_observer = MinMaxObserver.with_args(dtype=torch.float8_e5m2fnuz)
59
+ custom_qconfig = QConfig(
60
+ activation=float8_observer,
61
+ weight=float8_observer
62
+ )
63
+ qconfig_spec = {
64
+ "linear": custom_qconfig,
65
+ "linear_1": custom_qconfig,
66
+ "linear_2": custom_qconfig,
67
+ "to_q": custom_qconfig,
68
+ "to_k": custom_qconfig,
69
+ "to_v": custom_qconfig,
70
+ "add_k_proj": custom_qconfig,
71
+ "add_v_proj": custom_qconfig,
72
+ "add_q_proj": custom_qconfig,
73
+ "proj": custom_qconfig,
74
+ "proj_mlp": custom_qconfig,
75
+ "proj_out": custom_qconfig
76
+ }
77
+
78
+ # Apply dynamic quantization to Transformer
79
+ pipeline.transformer = quantize_dynamic(
80
+ pipeline.transformer,
81
+ qconfig_spec=qconfig_spec, # Apply qconfig only to transformer layers
82
+ dtype=torch.float8_e5m2fnuz,
83
+ inplace=True,
84
+ )
85
 
86
  pipeline.vae.to(memory_format=torch.channels_last)
87
  pipeline.vae = torch.compile(pipeline.vae)