NoMoreCopyright commited on
Commit
55a0db3
·
verified ·
1 Parent(s): fac23fc

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +20 -12
  2. requirements.txt +4 -1
handler.py CHANGED
@@ -1,27 +1,35 @@
1
  import os
2
  from typing import Any, Dict
3
 
4
- from diffusers import FluxPipeline, FluxTransformer2DModel
5
- from torchao.quantization import int8_weight_only, quantize_
6
  from PIL.Image import Image
7
  import torch
8
 
9
- from huggingface_inference_toolkit.logging import logger
 
 
 
 
 
 
 
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, **kwargs: Any) -> None: # type: ignore
 
13
  repo_id = "camenduru/FLUX.1-dev-diffusers"
14
  dtype = torch.bfloat16
15
- transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
16
- quantize_(transformer, int8_weight_only(), device="cuda")
17
- transformer.to(memory_format=torch.channels_last)
18
- transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
19
- self.pipeline = FluxPipeline.from_pretrained(repo_id, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")
20
- self.pipeline.vae.to(memory_format=torch.channels_last)
21
- self.pipeline.vae.decode = torch.compile(self.pipeline.vae.decode, mode="max-autotune", fullgraph=True)
22
-
23
  def __call__(self, data: Dict[str, Any]) -> Image:
24
- logger.info(f"Received incoming request with {data=}")
25
 
26
  if "inputs" in data and isinstance(data["inputs"], str):
27
  prompt = data.pop("inputs")
 
1
  import os
2
  from typing import Any, Dict
3
 
4
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
 
5
  from PIL.Image import Image
6
  import torch
7
 
8
+ import torch._dynamo
9
+ torch._dynamo.config.suppress_errors = True
10
+
11
+ #from huggingface_inference_toolkit.logging import logger
12
+
13
+ def compile_pipeline(pipe):
14
+ pipe.transformer.to(memory_format=torch.channels_last)
15
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
16
+ return pipe
17
 
18
  class EndpointHandler:
19
  def __init__(self, **kwargs: Any) -> None: # type: ignore
20
+ is_compile = False
21
  repo_id = "camenduru/FLUX.1-dev-diffusers"
22
  dtype = torch.bfloat16
23
+ quantization_config = TorchAoConfig("int4dq")
24
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
25
+ #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype).to("cuda")
26
+ self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
27
+ if is_compile: self.pipeline = compile_pipeline(self.pipeline)
28
+ self.pipeline.to("cuda")
29
+
30
+ @torch.inference_mode()
31
  def __call__(self, data: Dict[str, Any]) -> Image:
32
+ #logger.info(f"Received incoming request with {data=}")
33
 
34
  if "inputs" in data and isinstance(data["inputs"], str):
35
  prompt = data.pop("inputs")
requirements.txt CHANGED
@@ -1,7 +1,10 @@
1
  torch
 
2
  diffusers
3
  peft
4
  accelerate
5
  transformers
6
  numpy
7
- Pillow
 
 
 
1
  torch
2
+ torchvision
3
  diffusers
4
  peft
5
  accelerate
6
  transformers
7
  numpy
8
+ scipy
9
+ Pillow
10
+ triton