John6666 committed (verified)
Commit 817434a · Parent(s): aa10c96

Upload handler.py

Files changed (1): handler.py  +43 -261

handler.py CHANGED
@@ -1,267 +1,52 @@
- # https://github.com/sayakpaul/diffusers-torchao
- # https://github.com/pytorch/ao/releases
- # https://developer.nvidia.com/cuda-gpus
-
  import os
  from typing import Any, Dict
- import gc
- import time
  from PIL import Image
- from huggingface_hub import hf_hub_download
  import torch
- from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight, int8_weight_only
- from torchao.quantization.quant_api import PerRow
- from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
- from transformers import T5EncoderModel
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
  from huggingface_inference_toolkit.logging import logger
-
- import subprocess
- subprocess.run("pip list", shell=True)
-
- print("Device Name:", torch.cuda.get_device_name())
- print("Device Capability:", torch.cuda.get_device_capability())
-
- IS_TURBO = False
- IS_4BIT = False
- IS_PARA = True
- IS_MGPU = False
- IS_LVRAM = False
- IS_COMPILE = True
- IS_WARM = True
- IS_QUANT = True
- IS_AUTOQ = False
- IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
- IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False

  # Set high precision for float32 matrix multiplications.
  # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
  torch.set_float32_matmul_precision("high")

- if IS_COMPILE:
-     import torch._dynamo
-     torch._dynamo.config.suppress_errors = False
-     #torch._dynamo.config.suppress_errors = True
-     torch._inductor.config.disable_progress = False
-     #torch._inductor.config.conv_1x1_as_mm = True
-     #torch._inductor.config.coordinate_descent_tuning = True
-     #torch._inductor.config.coordinate_descent_check_all_directions = True
-     #torch._inductor.config.epilogue_fusion = False
-
- if IS_MGPU:
-     import torch.distributed as dist
-     dist.init_process_group()
-     torch.cuda.set_device(dist.get_rank())
-
- def print_vram():
-     free = torch.cuda.mem_get_info()[0] / (1024 ** 3)
-     total = torch.cuda.mem_get_info()[1] / (1024 ** 3)
-     print(f"VRAM: {total - free:.2f}/{total:.2f}GB")
-
- def pil_to_base64(image: Image.Image, modelname: str, prompt: str, height: int, width: int, steps: int, cfg: float, seed: int) -> str:
-     import base64
-     from io import BytesIO
-     import json
-     from PIL import PngImagePlugin
-     metadata = {"prompt": prompt, "num_inference_steps": steps, "guidance_scale": cfg, "seed": seed, "resolution": f"{width} x {height}",
-                 "Model": {"Model": modelname.split("/")[-1]}}
-     info = PngImagePlugin.PngInfo()
-     info.add_text("metadata", json.dumps(metadata))
-     buffered = BytesIO()
-     image.save(buffered, "PNG", pnginfo=info)
-     return base64.b64encode(buffered.getvalue()).decode('ascii')
-
- def load_te2(repo_id: str, dtype: torch.dtype) -> Any:
-     text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", torch_dtype=dtype)
-     if IS_4BIT: quantize_(text_encoder_2, int8_dynamic_activation_int4_weight())
-     else: quantize_(text_encoder_2, int8_dynamic_activation_int8_weight())
-     return text_encoder_2
-
- def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     return pipe
-
- def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
-     int4_config = TorchAoConfig("int4dq")
-     float8_config = TorchAoConfig("float8dq")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=None, transformer=None, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
-     pipe.transformer = transformer
-     pipe.vae = vae
-     #pipe.transformer.fuse_qkv_projections()
-     #pipe.vae.fuse_qkv_projections()
-     pipe.to("cuda")
-     return pipe
-
- def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     return pipe
-
- def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
-     pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
-     return pipe
-
- def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
-     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
-     pipe.fuse_lora()
-     pipe.unload_lora_weights()
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     quantize_(pipe.transformer, weight, device="cuda")
-     quantize_(pipe.vae, weight, device="cuda")
-     return pipe
-
- def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
-     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
-     pipe.fuse_lora()
-     pipe.unload_lora_weights()
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     quantize_(pipe.transformer, weight, device="cuda")
-     quantize_(pipe.vae, weight, device="cuda")
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.vae.to(memory_format=torch.channels_last)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     return pipe
-
- def load_pipeline_opt(repo_id: str, dtype: torch.dtype) -> Any:
-     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_CC90 else "int8wo")
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
-     #transformer.fuse_qkv_projections()
-     if IS_CC90: quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(transformer, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(transformer, weight, device="cuda")
-     transformer.to(memory_format=torch.channels_last)
-     #transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
-     transformer = torch.compile(transformer, mode="max-autotune-no-cudagraphs")
-     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
-     #vae.fuse_qkv_projections()
-     if IS_CC90: quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(vae, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(vae, weight, device="cuda")
-     vae.to(memory_format=torch.channels_last)
-     #vae = torch.compile(vae, mode="max-autotune", fullgraph=True)
-     vae = torch.compile(vae, mode="max-autotune-no-cudagraphs")
-     pipe = FluxPipeline.from_pretrained(repo_id, transformer=None, vae=None, torch_dtype=dtype, quantization_config=quantization_config)
-     quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
-     quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
-     pipe.transformer = transformer
-     pipe.vae = vae
-     return pipe
-
- def load_pipeline_para(repo_id: str, dtype: torch.dtype) -> Any:
-     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
-     if IS_PARA:
-         if IS_MGPU:
-             from para_attn.context_parallel import init_context_parallel_mesh
-             from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
-             from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
-             mesh = init_context_parallel_mesh(pipe.device.type, max_ring_dim_size=2)
-             parallelize_pipe(pipe, mesh=mesh)
-             parallelize_vae(pipe.vae, mesh=mesh._flatten())
-         apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
-     quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
-     quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
-     if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(pipe.transformer, weight, device="cuda")
-     #pipe.transformer.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
-     if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-     elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
-     else: quantize_(pipe.vae, weight, device="cuda")
-     #pipe.vae.fuse_qkv_projections()
-     pipe.vae.to(memory_format=torch.channels_last)
-     #pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-     pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
-     return pipe
-
- def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
-     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
-     pipe.enable_vae_slicing()
-     pipe.enable_vae_tiling()
-     pipe.transformer.fuse_qkv_projections()
-     pipe.vae.fuse_qkv_projections()
-     pipe.transformer.to(memory_format=torch.channels_last)
-     pipe.vae.to(memory_format=torch.channels_last)
-     if IS_QUANT and not IS_AUTOQ:
-         quantize_(pipe.text_encoder, int8_weight_only())
-         quantize_(pipe.text_encoder_2, int8_weight_only())
-         if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-         elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
-         else: quantize_(pipe.vae, int8_weight_only())
-         if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
-         elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
-         else: quantize_(pipe.vae, int8_weight_only())
-     return pipe

  class EndpointHandler:
      def __init__(self, path=""):
-         repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
-         self.repo_id = repo_id
-         dtype = torch.bfloat16
-         #dtype = torch.float16 # for older nVidia GPUs
-         print_vram()
-         print("Loading pipeline...")
-         if IS_AUTOQ: self.pipeline = load_pipeline_fast(repo_id, dtype)
-         elif IS_COMPILE: self.pipeline = load_pipeline_fast(repo_id, dtype)
-         elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
-         else: self.pipeline = load_pipeline_stable(repo_id, dtype)
-         self.pipeline.enable_vae_slicing()
-         self.pipeline.enable_vae_tiling()
-         self.pipeline.to("cuda")
-         if IS_PARA: apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
-         if IS_COMPILE:
-             print("Compiling pipeline...")
-             self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="max-autotune-no-cudagraphs")
-             self.pipeline.vae = torch.compile(self.pipeline.vae, mode="max-autotune-no-cudagraphs")
-         if IS_AUTOQ:
-             print("Running autoquant...")
-             self.pipeline.transformer = autoquant(self.pipeline.transformer, error_on_unseen=False)
-             self.pipeline.vae = autoquant(self.pipeline.vae, error_on_unseen=False)
-         gc.collect()
-         torch.cuda.empty_cache()
-         print_vram()
-         if IS_WARM:
-             print("Warming pipeline...")
-             start = time.time()
-             self.pipeline("Hello world!", output_type="pil")
-             end = time.time()
-             print(f'Compiled in {end - start:.3f} sec.')
-
-     def __call__(self, data: Dict[str, Any]) -> str:
          logger.info(f"Received incoming request with {data=}")

          if "inputs" in data and isinstance(data["inputs"], str):
@@ -276,7 +61,7 @@ class EndpointHandler:

          parameters = data.pop("parameters", {})

-         num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
          width = parameters.get("width", 1024)
          height = parameters.get("height", 1024)
          guidance_scale = parameters.get("guidance_scale", 3.5)
@@ -284,20 +69,17 @@
          # seed generator (seed cannot be provided as is but via a generator)
          seed = parameters.get("seed", 0)
          generator = torch.manual_seed(seed)
-
-         start = time.time()
-         image = self.pipeline( # type: ignore
              prompt,
              height=height,
              width=width,
              guidance_scale=guidance_scale,
              num_inference_steps=num_inference_steps,
              generator=generator,
-             output_type="pil",
          ).images[0]
-         end = time.time()
-         print(f'Elapsed {end - start:.3f} sec. / prompt:"{prompt}" / size:{width}x{height} / steps:{num_inference_steps} / guidance scale:{guidance_scale} / seed:{seed}')
-
-         return pil_to_base64(image, self.repo_id, prompt, height, width, num_inference_steps, guidance_scale, seed)
-

handler.py (after):

  import os
  from typing import Any, Dict
  from PIL import Image
  import torch
+ from diffusers import FluxPipeline
  from huggingface_inference_toolkit.logging import logger
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+ from torchao.quantization import autoquant
+ import time

  # Set high precision for float32 matrix multiplications.
  # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
  torch.set_float32_matmul_precision("high")

+ import torch._dynamo
+ torch._dynamo.config.suppress_errors = False # for debugging

  class EndpointHandler:
      def __init__(self, path=""):
+         self.pipe = FluxPipeline.from_pretrained(
+             "NoMoreCopyrightOrg/flux-dev",
+             torch_dtype=torch.bfloat16,
+         ).to("cuda")
+         self.pipe.enable_vae_slicing()
+         self.pipe.enable_vae_tiling()
+         self.pipe.transformer.fuse_qkv_projections()
+         self.pipe.vae.fuse_qkv_projections()
+         self.pipe.transformer.to(memory_format=torch.channels_last)
+         self.pipe.vae.to(memory_format=torch.channels_last)
+         apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
+         self.pipe.transformer = torch.compile(
+             self.pipe.transformer, mode="max-autotune-no-cudagraphs",
+         )
+         self.pipe.vae = torch.compile(
+             self.pipe.vae, mode="max-autotune-no-cudagraphs",
+         )
+         self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
+         self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
+         self.pipe.text_encoder = autoquant(self.pipe.text_encoder, error_on_unseen=False)
+         self.pipe.text_encoder_2 = autoquant(self.pipe.text_encoder_2, error_on_unseen=False)
+
+         start_time = time.time()
+         print("Start warming-up pipeline")
+         self.pipe("Hello world!") # Warm-up for compiling
+         end_time = time.time()
+         time_taken = end_time - start_time
+         print(f"Time taken: {time_taken:.2f} seconds")
+
+     def __call__(self, data: Dict[str, Any]) -> Image.Image:
          logger.info(f"Received incoming request with {data=}")

          if "inputs" in data and isinstance(data["inputs"], str):

@@ -276,7 +61,7 @@ class EndpointHandler:

          parameters = data.pop("parameters", {})

+         num_inference_steps = parameters.get("num_inference_steps", 28)
          width = parameters.get("width", 1024)
          height = parameters.get("height", 1024)
          guidance_scale = parameters.get("guidance_scale", 3.5)

@@ -284,20 +69,17 @@
          # seed generator (seed cannot be provided as is but via a generator)
          seed = parameters.get("seed", 0)
          generator = torch.manual_seed(seed)
+         start_time = time.time()
+         result = self.pipe( # type: ignore
              prompt,
              height=height,
              width=width,
              guidance_scale=guidance_scale,
              num_inference_steps=num_inference_steps,
              generator=generator,
          ).images[0]
+         end_time = time.time()
+         time_taken = end_time - start_time
+         print(f"Time taken: {time_taken:.2f} seconds")

+         return result
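
For reference, a minimal local smoke test of the updated handler might look like the sketch below. It assumes handler.py is importable from the working directory and a CUDA GPU is available; the payload keys mirror the data dict that __call__ reads ("inputs" for the prompt, plus optional "parameters"), and after this commit the handler returns a PIL.Image.Image directly rather than a base64-encoded string. The prompt and parameter values here are illustrative, not part of the commit.

    # Sketch of a local smoke test -- assumptions noted above, not part of the commit.
    from handler import EndpointHandler

    handler = EndpointHandler()  # loads, compiles, and warms up the FLUX pipeline

    payload = {
        "inputs": "A watercolor painting of a lighthouse at dawn",  # example prompt
        "parameters": {
            "num_inference_steps": 28,  # matches the handler's default
            "width": 1024,
            "height": 1024,
            "guidance_scale": 3.5,
            "seed": 42,
        },
    }

    image = handler(payload)  # PIL.Image.Image after this commit
    image.save("output.png")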