English
John6666 committed on
Commit
2b7dcd2
·
verified ·
1 Parent(s): e23490b

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -7
handler.py CHANGED
@@ -2,20 +2,24 @@
2
 
3
  import os
4
  from typing import Any, Dict
5
-
6
- from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
7
  from PIL import Image
8
- import torch
9
- from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
10
  from huggingface_hub import hf_hub_download
11
- import gc
 
 
 
 
 
 
 
12
 
13
  import subprocess
14
  subprocess.run("pip list", shell=True)
15
 
16
  IS_COMPILE = True
17
  IS_TURBO = False
18
- IS_4BIT = True
19
 
20
  #if IS_COMPILE:
21
  # import torch._dynamo
@@ -92,12 +96,26 @@ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
92
  pipe.to("cuda")
93
  return pipe
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  class EndpointHandler:
96
  def __init__(self, path=""):
97
  repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
98
  dtype = torch.bfloat16
99
  #dtype = torch.float16 # for older nVidia GPUs
100
- if IS_COMPILE: self.pipeline = load_pipeline_compile(repo_id, dtype)
101
  else: self.pipeline = load_pipeline_stable(repo_id, dtype)
102
  gc.collect()
103
  torch.cuda.empty_cache()
 
2
 
3
  import os
4
  from typing import Any, Dict
5
+ import gc
 
6
  from PIL import Image
 
 
7
  from huggingface_hub import hf_hub_download
8
+ import torch
9
+ from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight
10
+ from torchao.quantization.quant_api import PerRow
11
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
12
+
13
+ # Use the "high" float32 matmul precision mode, which lets PyTorch pick faster (e.g. TF32-based) kernels at slightly reduced precision.
14
+ # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
15
+ torch.set_float32_matmul_precision("high")
16
 
17
  import subprocess
18
  subprocess.run("pip list", shell=True)
19
 
20
  IS_COMPILE = True
21
  IS_TURBO = False
22
+ IS_4BIT = False
23
 
24
  #if IS_COMPILE:
25
  # import torch._dynamo
 
96
  pipe.to("cuda")
97
  return pipe
98
 
99
def load_pipeline_opt(repo_id: str, dtype: torch.dtype) -> Any:
    """Build a quantized, torch.compile'd FLUX pipeline placed on CUDA.

    The transformer is loaded on its own from ``repo_id`` (subfolder
    ``transformer``), optimized, and then handed to the pipeline
    constructor. The pipeline's VAE goes through the same optimization
    steps once the pipeline has been moved to CUDA.

    Args:
        repo_id: Model repository identifier passed to ``from_pretrained``.
        dtype: Torch dtype used when loading the weights.

    Returns:
        The assembled pipeline, after being moved to ``"cuda"``.
    """

    def _optimize(module: Any) -> Any:
        # Same sequence for every sub-model: fuse QKV projections, apply
        # float8 dynamic-activation/float8-weight quantization with per-row
        # granularity, switch to channels_last, and finally compile.
        module.fuse_qkv_projections()
        fp8_config = float8_dynamic_activation_float8_weight(granularity=PerRow())
        quantize_(module, fp8_config, device="cuda")
        module.to(memory_format=torch.channels_last)
        return torch.compile(module, mode="max-autotune", fullgraph=True)

    optimized_transformer = _optimize(
        FluxTransformer2DModel.from_pretrained(
            repo_id, subfolder="transformer", torch_dtype=dtype
        )
    )

    pipeline = FluxPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, transformer=optimized_transformer
    )
    pipeline = pipeline.to("cuda")

    # The VAE is optimized only after the pipeline already lives on the GPU.
    pipeline.vae = _optimize(pipeline.vae)

    # Kept from the original: a final move after the VAE has been swapped in.
    pipeline.to("cuda")
    return pipeline
112
+
113
  class EndpointHandler:
114
  def __init__(self, path=""):
115
  repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
116
  dtype = torch.bfloat16
117
  #dtype = torch.float16 # for older nVidia GPUs
118
+ if IS_COMPILE: self.pipeline = load_pipeline_opt(repo_id, dtype)
119
  else: self.pipeline = load_pipeline_stable(repo_id, dtype)
120
  gc.collect()
121
  torch.cuda.empty_cache()