John6666 committed on
Commit f28bd15 · verified · 1 Parent(s): 2393f58

Upload handler.py

Files changed (1)
  1. handler.py +6 -5
handler.py CHANGED

@@ -4,6 +4,8 @@ from typing import Any, Dict
 from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
 from PIL import Image
 import torch
+from accelerate import PartialState
+distributed_state = PartialState()
 
 IS_COMPILE = False
 
@@ -14,9 +16,7 @@ if IS_COMPILE:
 #from huggingface_inference_toolkit.logging import logger
 
 def compile_pipeline(pipe) -> Any:
-    pipe.transformer.to(memory_format=torch.channels_last)
     pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
-    pipe.vae.to(memory_format=torch.channels_last)
     pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
     return pipe
 
@@ -29,11 +29,12 @@ class EndpointHandler:
         vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
         #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
         self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
-        self.pipeline.transformer.fuse_qkv_projections()
-        self.pipeline.vae.fuse_qkv_projections()
+        self.pipeline.transformer.fuse_qkv_projections().to(memory_format=torch.channels_last)
+        self.pipeline.vae.fuse_qkv_projections().to(memory_format=torch.channels_last)
         if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
-        self.pipeline.to("cuda")
+        self.pipeline.to(distributed_state.device)
 
+    @torch.inference_mode()
     def __call__(self, data: Dict[str, Any]) -> Image.Image:
         #logger.info(f"Received incoming request with {data=}")
 
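
The change swaps the hard-coded "cuda" for accelerate's PartialState().device and wraps __call__ in torch.inference_mode(). Below is a minimal local smoke-test sketch, not part of the commit: it assumes EndpointHandler follows the usual Hugging Face custom-handler convention of __init__(self, path="") and a payload with an "inputs" prompt key; adjust to the actual handler.py signature.

# sketch only: local check of device resolution and a single handler call
import torch
from accelerate import PartialState

# PartialState().device resolves to the local accelerator ("cuda:N" when a GPU
# is visible, otherwise "cpu"), which is what the commit now passes to .to()
# instead of the hard-coded "cuda".
device = PartialState().device
print(f"pipeline will be placed on: {device}")

if __name__ == "__main__":
    from handler import EndpointHandler

    # the path argument and payload keys are assumptions, not taken from this diff
    handler = EndpointHandler(path=".")
    image = handler({"inputs": "a photo of an astronaut riding a horse"})
    image.save("sample.png")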