refoundd committed on
Commit
f6a8db2
·
verified ·
1 Parent(s): 45ccc55

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -41
handler.py CHANGED
@@ -4,28 +4,20 @@ import os
4
  from typing import Any, Dict
5
  from PIL import Image
6
  import torch
7
- from diffusers import FluxPipeline
8
  from huggingface_inference_toolkit.logging import logger
9
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
10
- from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
11
- import time
 
 
 
 
12
 
13
  class EndpointHandler:
14
  def __init__(self,path=""):
15
 
16
- self.pipe = FluxPipeline.from_pretrained(
17
- "NoMoreCopyrightOrg/flux-dev",
18
- torch_dtype=torch.bfloat16,
19
- ).to("cuda")
20
- apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
21
- quantize_(self.pipe.text_encoder, float8_weight_only())
22
- quantize_(self.pipe.transformer, float8_dynamic_activation_float8_weight())
23
- self.pipe.transformer = torch.compile(
24
- self.pipe.transformer, mode="max-autotune-no-cudagraphs",
25
- )
26
- self.pipe.vae = torch.compile(
27
- self.pipe.vae, mode="max-autotune-no-cudagraphs",
28
- )
29
 
30
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
31
  logger.info(f"Received incoming request with {data=}")
@@ -40,27 +32,4 @@ class EndpointHandler:
40
  " prompt to use for the image generation, and it needs to be a non-empty string."
41
  )
42
 
43
- parameters = data.pop("parameters", {})
44
-
45
- num_inference_steps = parameters.get("num_inference_steps", 28)
46
- width = parameters.get("width", 1024)
47
- height = parameters.get("height", 1024)
48
- guidance_scale = parameters.get("guidance_scale", 3.5)
49
-
50
- # seed generator (seed cannot be provided as is but via a generator)
51
- seed = parameters.get("seed", 0)
52
- generator = torch.manual_seed(seed)
53
- start_time = time.time()
54
- result = self.pipe( # type: ignore
55
- prompt,
56
- height=height,
57
- width=width,
58
- guidance_scale=guidance_scale,
59
- num_inference_steps=num_inference_steps,
60
- generator=generator,
61
- # output_type="pil" if dist.get_rank() == 0 else "pt",
62
- ).images[0]
63
- end_time = time.time()
64
- time_taken = end_time - start_time
65
- print(f"Time taken: {time_taken:.2f} seconds")
66
- return result
 
4
  from typing import Any, Dict
5
  from PIL import Image
6
  import torch
7
+ import torch.distributed as dist
8
  from huggingface_inference_toolkit.logging import logger
9
+
10
+ dist.init_process_group()
11
+ torch.cuda.set_device(dist.get_rank())
12
+
13
+ from para_attn.context_parallel import init_context_parallel_mesh
14
+ from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
15
+ from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
16
 
17
  class EndpointHandler:
18
  def __init__(self,path=""):
19
 
20
+
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
23
  logger.info(f"Received incoming request with {data=}")
 
32
  " prompt to use for the image generation, and it needs to be a non-empty string."
33
  )
34
 
35
+ return "1"