English
John6666 committed on
Commit
f7b87cf
·
verified ·
1 Parent(s): 3fdb494

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -14
handler.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
 
7
 
8
  IS_COMPILE = True
9
 
@@ -11,33 +12,53 @@ if IS_COMPILE:
11
  import torch._dynamo
12
  torch._dynamo.config.suppress_errors = True
13
 
14
- #from huggingface_inference_toolkit.logging import logger
15
 
16
- def compile_pipeline(pipe) -> Any:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  pipe.transformer.to(memory_format=torch.channels_last)
18
- #pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
19
  pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
20
  pipe.vae.to(memory_format=torch.channels_last)
21
- #pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
22
  pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
 
 
 
23
  return pipe
24
 
25
  class EndpointHandler:
26
  def __init__(self, path=""):
27
  repo_id = "camenduru/FLUX.1-dev-diffusers"
28
- #repo_id = "NoMoreCopyright/FLUX.1-dev-test"
29
  dtype = torch.bfloat16
30
- quantization_config = TorchAoConfig("int8dq")
31
- vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
32
- #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
33
- self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
34
- self.pipeline.transformer.fuse_qkv_projections()
35
- self.pipeline.vae.fuse_qkv_projections()
36
- if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
37
- self.pipeline.to("cuda")
38
 
39
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
40
- #logger.info(f"Received incoming request with {data=}")
41
 
42
  if "inputs" in data and isinstance(data["inputs"], str):
43
  prompt = data.pop("inputs")
 
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
7
+ from torchao.quantization import quantize_, autoquant
8
 
9
  IS_COMPILE = True
10
 
 
12
  import torch._dynamo
13
  torch._dynamo.config.suppress_errors = True
14
 
15
+ from huggingface_inference_toolkit.logging import logger
16
 
17
def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
    """Load the Flux pipeline without torch.compile.

    Builds a ``TorchAoConfig("int8dq")`` quantization config, loads the VAE
    from the ``vae`` subfolder of ``repo_id``, and hands both to
    ``FluxPipeline.from_pretrained``. QKV projections are then fused on the
    transformer and the VAE, and the pipeline is moved to CUDA.

    Args:
        repo_id: Repository identifier passed to ``from_pretrained``.
        dtype: Torch dtype used for both the VAE and the pipeline weights.

    Returns:
        The loaded pipeline object.
    """
    int8_config = TorchAoConfig("int8dq")
    autoencoder = AutoencoderKL.from_pretrained(
        repo_id,
        subfolder="vae",
        torch_dtype=dtype,
    )
    pipeline = FluxPipeline.from_pretrained(
        repo_id,
        vae=autoencoder,
        torch_dtype=dtype,
        quantization_config=int8_config,
    )

    # Fuse the attention projections on both sub-models before moving devices.
    pipeline.transformer.fuse_qkv_projections()
    pipeline.vae.fuse_qkv_projections()

    pipeline.to("cuda")
    return pipeline
25
+
26
def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
    """Load the Flux pipeline and wrap its sub-models with torch.compile.

    Loading mirrors the non-compiled path: a ``TorchAoConfig("int8dq")``
    config, a separately loaded VAE, and fused QKV projections. The
    transformer and then the VAE are switched to ``channels_last`` memory
    format and replaced by their ``torch.compile`` wrappers
    (``reduce-overhead`` mode, inductor backend) before the pipeline is
    moved to CUDA.

    Args:
        repo_id: Repository identifier passed to ``from_pretrained``.
        dtype: Torch dtype used for both the VAE and the pipeline weights.

    Returns:
        The loaded pipeline object with compiled transformer and VAE.
    """
    int8_config = TorchAoConfig("int8dq")
    autoencoder = AutoencoderKL.from_pretrained(
        repo_id,
        subfolder="vae",
        torch_dtype=dtype,
    )
    pipeline = FluxPipeline.from_pretrained(
        repo_id,
        vae=autoencoder,
        torch_dtype=dtype,
        quantization_config=int8_config,
    )

    pipeline.transformer.fuse_qkv_projections()
    pipeline.vae.fuse_qkv_projections()

    compile_options = {
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
        "backend": "inductor",
    }
    # Transformer first, then VAE: same order as the explicit statements.
    for component in ("transformer", "vae"):
        getattr(pipeline, component).to(memory_format=torch.channels_last)
        compiled = torch.compile(getattr(pipeline, component), **compile_options)
        setattr(pipeline, component, compiled)

    pipeline.to("cuda")
    return pipeline
38
+
39
def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
    """Load the Flux pipeline, compile its sub-models, and apply autoquant.

    The pipeline is loaded without a quantization config and moved to CUDA
    right away. QKV projections are fused on the transformer and the VAE;
    each is then switched to ``channels_last`` and wrapped with
    ``torch.compile`` (``max-autotune``, ``fullgraph=True``). Finally both
    compiled modules are passed through ``autoquant`` with
    ``error_on_unseen=False`` and the pipeline is moved to CUDA again.

    Args:
        repo_id: Repository identifier passed to ``from_pretrained``.
        dtype: Torch dtype used for the pipeline weights.

    Returns:
        The loaded pipeline object with compiled, autoquant-wrapped
        transformer and VAE.
    """
    components = ("transformer", "vae")

    pipeline = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")

    for component in components:
        getattr(pipeline, component).fuse_qkv_projections()

    # Compile each sub-model after switching it to channels_last.
    for component in components:
        getattr(pipeline, component).to(memory_format=torch.channels_last)
        compiled = torch.compile(
            getattr(pipeline, component),
            mode="max-autotune",
            fullgraph=True,
        )
        setattr(pipeline, component, compiled)

    # autoquant is applied on top of the already-compiled modules.
    for component in components:
        quantized = autoquant(getattr(pipeline, component), error_on_unseen=False)
        setattr(pipeline, component, quantized)

    pipeline.to("cuda")
    return pipeline
51
 
52
  class EndpointHandler:
53
  def __init__(self, path=""):
54
  repo_id = "camenduru/FLUX.1-dev-diffusers"
 
55
  dtype = torch.bfloat16
56
+ self.pipeline = load_pipeline_autoquant(repo_id, dtype)
57
+ #if IS_COMPILE: self.pipeline = load_pipeline_compile(repo_id, dtype)
58
+ #else: self.pipeline = load_pipeline_stable(repo_id, dtype)
 
 
 
 
 
59
 
60
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
61
+ logger.info(f"Received incoming request with {data=}")
62
 
63
  if "inputs" in data and isinstance(data["inputs"], str):
64
  prompt = data.pop("inputs")