refoundd committed on
Commit
fe14a7c
·
verified ·
1 Parent(s): 99605e9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +67 -17
handler.py CHANGED
@@ -1,18 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
1
 
2
-  builtins
3
- @overload def print(*values: object,
4
- sep: str | None = " ",
5
- end: str | None = "\n",
6
- file: SupportsWrite[str] | None = None,
7
- flush: Literal[False] = False) -> None
8
- Prints the values to a stream, or to sys.stdout by default.
9
-
10
- sep
11
- string inserted between values, default a space.
12
- end
13
- string appended after the last value, default a newline.
14
- file
15
- a file-like object (stream); defaults to the current sys.stdout.
16
- flush
17
- whether to forcibly flush the stream.
18
- `print(*values, sep=" ", end="\n", file=None, flush=False)` from docs.python.org
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/sayakpaul/diffusers-torchao
2
+ #8s
3
+ import os
4
+ from typing import Any, Dict
5
+ from PIL import Image
6
+ import torch
7
+ from diffusers import FluxPipeline
8
+ from huggingface_inference_toolkit.logging import logger
9
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
10
+ from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
11
+ import time
12
 
13
class EndpointHandler:
    """Inference-endpoint handler serving a FLUX text-to-image pipeline.

    Loads the pipeline once at start-up, applies first-block caching and
    ``torch.compile`` to speed up inference, then serves generation requests
    through ``__call__`` following the huggingface-inference-toolkit contract.
    """

    def __init__(self, path=""):
        # `path` is part of the toolkit handler contract; the model id is
        # hard-coded below, so the argument is intentionally unused.
        self.pipe = FluxPipeline.from_pretrained(
            "NoMoreCopyrightOrg/flux-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # First-block cache (para_attn): reuse transformer output across steps
        # when the residual difference stays below the threshold.
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
        # quantize_(self.pipe.text_encoder, float8_weight_only())
        # quantize_(self.pipe.transformer, float8_dynamic_activation_float8_weight())
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipe.vae = torch.compile(
            self.pipe.vae, mode="max-autotune-no-cudagraphs",
        )

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Generate one image for the prompt contained in `data`.

        The prompt is read from `data["inputs"]` or `data["prompt"]`;
        optional generation settings come from `data["parameters"]`
        (num_inference_steps, width, height, guidance_scale, seed).

        Raises:
            ValueError: if neither key holds a non-empty string prompt.
        """
        logger.info(f"Received incoming request with {data=}")

        prompt = None
        if isinstance(data.get("inputs"), str):
            prompt = data.pop("inputs")
        elif isinstance(data.get("prompt"), str):
            prompt = data.pop("prompt")
        # Fix: the message below promises a *non-empty* string, but the
        # original code accepted "" — enforce the stated contract.
        if prompt is None or not prompt.strip():
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the image generation, and it needs to be a non-empty string."
            )

        parameters = data.pop("parameters", {})

        num_inference_steps = parameters.get("num_inference_steps", 28)
        width = parameters.get("width", 1024)
        height = parameters.get("height", 1024)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        # Seed cannot be passed directly to the pipeline; it goes in via a
        # generator. NOTE(review): torch.manual_seed seeds the *global* RNG,
        # so concurrent requests could interfere — a per-request
        # torch.Generator(device).manual_seed(seed) would isolate them.
        seed = parameters.get("seed", 0)
        generator = torch.manual_seed(seed)

        start_time = time.time()
        result = self.pipe(  # type: ignore
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            # output_type="pil" if dist.get_rank() == 0 else "pt",
        ).images[0]
        time_taken = time.time() - start_time
        # Use the module logger instead of bare print() so this lands in the
        # endpoint logs alongside the request log above.
        logger.info(f"Generated {result!r} — time taken: {time_taken:.2f} seconds")
        return result