refoundd commited on
Commit
99605e9
·
verified ·
1 Parent(s): fc40e9e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -65
handler.py CHANGED
@@ -1,66 +1,18 @@
1
- # https://github.com/sayakpaul/diffusers-torchao
2
- #8s
3
- import os
4
- from typing import Any, Dict
5
- from PIL import Image
6
- import torch
7
- from diffusers import FluxPipeline
8
- from huggingface_inference_toolkit.logging import logger
9
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
10
- from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
11
- import time
12
 
13
- class EndpointHandler:
14
- def __init__(self,path=""):
15
-
16
- self.pipe = FluxPipeline.from_pretrained(
17
- "NoMoreCopyrightOrg/flux-dev",
18
- torch_dtype=torch.bfloat16,
19
- ).to("cuda")
20
- apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
21
- # quantize_(self.pipe.text_encoder, float8_weight_only())
22
- # quantize_(self.pipe.transformer, float8_dynamic_activation_float8_weight())
23
- self.pipe.transformer = torch.compile(
24
- self.pipe.transformer, mode="max-autotune-no-cudagraphs",
25
- )
26
- self.pipe.vae = torch.compile(
27
- self.pipe.vae, mode="max-autotune-no-cudagraphs",
28
- )
29
-
30
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
31
- logger.info(f"Received incoming request with {data=}")
32
-
33
- if "inputs" in data and isinstance(data["inputs"], str):
34
- prompt = data.pop("inputs")
35
- elif "prompt" in data and isinstance(data["prompt"], str):
36
- prompt = data.pop("prompt")
37
- else:
38
- raise ValueError(
39
- "Provided input body must contain either the key `inputs` or `prompt` with the"
40
- " prompt to use for the image generation, and it needs to be a non-empty string."
41
- )
42
-
43
- parameters = data.pop("parameters", {})
44
-
45
- num_inference_steps = parameters.get("num_inference_steps", 28)
46
- width = parameters.get("width", 1024)
47
- height = parameters.get("height", 1024)
48
- guidance_scale = parameters.get("guidance_scale", 3.5)
49
-
50
- # seed generator (seed cannot be provided as is but via a generator)
51
- seed = parameters.get("seed", 0)
52
- generator = torch.manual_seed(seed)
53
- start_time = time.time()
54
- result = self.pipe( # type: ignore
55
- prompt,
56
- height=height,
57
- width=width,
58
- guidance_scale=guidance_scale,
59
- num_inference_steps=num_inference_steps,
60
- generator=generator,
61
- # output_type="pil" if dist.get_rank() == 0 else "pt",
62
- ).images[0]
63
- end_time = time.time()
64
- time_taken = end_time - start_time
65
- print(f"Time taken: {time_taken:.2f} seconds")
66
- return result
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+  builtins
3
+ @overload def print(*values: object,
4
+ sep: str | None = " ",
5
+ end: str | None = "\n",
6
+ file: SupportsWrite[str] | None = None,
7
+ flush: Literal[False] = False) -> None
8
+ Prints the values to a stream, or to sys. stdout by default.
9
+
10
+ sep
11
+ string inserted between values, default a space.
12
+ end
13
+ string appended after the last value, default a newline.
14
+ file
15
+ a file-like object (stream); defaults to the current sys. stdout.
16
+ flush
17
+ whether to forcibly flush the stream.
18
+ docs. python. org 的 `print(*values, sep=" ", end="\n", file=None, flush=False)`