Update handler.py
Browse files- handler.py +17 -65
handler.py
CHANGED
|
@@ -1,66 +1,18 @@
|
|
| 1 |
-
# https://github.com/sayakpaul/diffusers-torchao
|
| 2 |
-
#8s
|
| 3 |
-
import os
|
| 4 |
-
from typing import Any, Dict
|
| 5 |
-
from PIL import Image
|
| 6 |
-
import torch
|
| 7 |
-
from diffusers import FluxPipeline
|
| 8 |
-
from huggingface_inference_toolkit.logging import logger
|
| 9 |
-
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
|
| 10 |
-
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
|
| 11 |
-
import time
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def __call__(self, data: Dict[str, Any]) -> Image.Image:
    """Generate one image from a JSON-style inference request.

    Args:
        data: Request payload. The prompt must be provided as a string under
            either ``inputs`` or ``prompt``. An optional ``parameters`` dict
            may carry ``num_inference_steps`` (default 28), ``width`` (default
            1024), ``height`` (default 1024), ``guidance_scale`` (default 3.5)
            and ``seed`` (default 0).

    Returns:
        The first generated ``PIL.Image.Image`` from the pipeline output.

    Raises:
        ValueError: If neither ``inputs`` nor ``prompt`` holds a non-empty
            string.
    """
    logger.info(f"Received incoming request with {data=}")

    if "inputs" in data and isinstance(data["inputs"], str):
        prompt = data.pop("inputs")
    elif "prompt" in data and isinstance(data["prompt"], str):
        prompt = data.pop("prompt")
    else:
        raise ValueError(
            "Provided input body must contain either the key `inputs` or `prompt` with the"
            " prompt to use for the image generation, and it needs to be a non-empty string."
        )

    # The error message promises a *non-empty* prompt, but the original code
    # accepted "" silently — enforce the documented contract here.
    if not prompt.strip():
        raise ValueError(
            "Provided input body must contain either the key `inputs` or `prompt` with the"
            " prompt to use for the image generation, and it needs to be a non-empty string."
        )

    parameters = data.pop("parameters", {})

    num_inference_steps = parameters.get("num_inference_steps", 28)
    width = parameters.get("width", 1024)
    height = parameters.get("height", 1024)
    guidance_scale = parameters.get("guidance_scale", 3.5)

    # seed generator (seed cannot be provided as is but via a generator)
    seed = parameters.get("seed", 0)
    generator = torch.manual_seed(seed)

    # perf_counter is monotonic and intended for measuring elapsed time,
    # unlike time.time() which can jump with wall-clock adjustments.
    start_time = time.perf_counter()
    result = self.pipe(  # type: ignore
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    time_taken = time.perf_counter() - start_time
    # Report timing through the module logger (not print) so it appears
    # alongside the rest of the request logs.
    logger.info("Time taken: %.2f seconds", time_taken)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
builtins
|
| 3 |
+
@overload def print(*values: object,
|
| 4 |
+
sep: str | None = " ",
|
| 5 |
+
end: str | None = "\n",
|
| 6 |
+
file: SupportsWrite[str] | None = None,
|
| 7 |
+
flush: Literal[False] = False) -> None
|
| 8 |
+
Prints the values to a stream, or to sys.stdout by default.
|
| 9 |
+
|
| 10 |
+
sep
|
| 11 |
+
string inserted between values, default a space.
|
| 12 |
+
end
|
| 13 |
+
string appended after the last value, default a newline.
|
| 14 |
+
file
|
| 15 |
+
a file-like object (stream); defaults to the current sys.stdout.
|
| 16 |
+
flush
|
| 17 |
+
whether to forcibly flush the stream.
|
| 18 |
+
Per docs.python.org, the full signature is `print(*values, sep=" ", end="\n", file=None, flush=False)`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|