refoundd committed
Commit bef70e4 · verified · 1 Parent(s): 1335bc3

Upload 2 files

Files changed (2)
handler.py +96 -0
requirements.txt +15 -0
handler.py ADDED
@@ -0,0 +1,96 @@
+ # https://github.com/sayakpaul/diffusers-torchao
+
+ import os
+ from typing import Any, Dict
+ import gc
+ from PIL import Image
+ import torch
+ from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
+ from torchao.quantization.quant_api import PerRow
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
+
+ # Run float32 matmuls in "high" (TF32) precision, trading a little accuracy for speed.
+ # This benefits NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
+ torch.set_float32_matmul_precision("high")
+
+ import subprocess
+ subprocess.run("pip list", shell=True)  # log installed packages for debugging
+
+ IS_COMPILE = True  # use the fp8-quantized, torch.compile-optimized pipeline
+ IS_TURBO = False  # use the 8-step distilled checkpoint instead of the full model
+ IS_4BIT = False  # int4 instead of int8 weights in the non-compiled path
+
+ # if IS_COMPILE:
+ #     import torch._dynamo
+ #     torch._dynamo.config.suppress_errors = True
+
+ from huggingface_inference_toolkit.logging import logger
+
+ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
+     quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
+     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
+     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     pipe.to("cuda")
+     return pipe
+
+ def load_pipeline_opt(repo_id: str, dtype: torch.dtype) -> Any:
+     transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
+     transformer.fuse_qkv_projections()
+     quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
+     transformer.to(memory_format=torch.channels_last)
+     transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype, transformer=transformer).to("cuda")
+     pipe.vae.fuse_qkv_projections()
+     quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
+     pipe.vae.to(memory_format=torch.channels_last)
+     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
+     pipe.to("cuda")
+     return pipe
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
+         dtype = torch.bfloat16
+         # dtype = torch.float16  # for older NVIDIA GPUs
+         loader = load_pipeline_opt if IS_COMPILE else load_pipeline_stable
+         self.pipeline = loader(repo_id, dtype)
+         gc.collect()
+         torch.cuda.empty_cache()
+
+     def __call__(self, data: Dict[str, Any]) -> Image.Image:
+         logger.info(f"Received incoming request with {data=}")
+
+         if "inputs" in data and isinstance(data["inputs"], str) and data["inputs"]:
+             prompt = data.pop("inputs")
+         elif "prompt" in data and isinstance(data["prompt"], str) and data["prompt"]:
+             prompt = data.pop("prompt")
+         else:
+             raise ValueError(
+                 "Provided input body must contain either the key `inputs` or `prompt` with the"
+                 " prompt to use for the image generation, and it needs to be a non-empty string."
+             )
+
+         parameters = data.pop("parameters", {})
+
+         num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
+         width = parameters.get("width", 1024)
+         height = parameters.get("height", 1024)
+         guidance_scale = parameters.get("guidance_scale", 3.5)
+
+         # The seed cannot be passed to the pipeline directly; it goes in via a torch.Generator.
+         seed = parameters.get("seed", 0)
+         generator = torch.manual_seed(seed)
+
+         return self.pipeline(  # type: ignore
+             prompt,
+             height=height,
+             width=width,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             output_type="pil",
+         ).images[0]
+
+
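For reference, a minimal sketch of how a deployed Inference Endpoint built on this handler could be called. The endpoint URL and token are placeholders, and the Accept-header handling assumes the huggingface_inference_toolkit's default serialization of the returned PIL image; none of this is part of the commit itself. The payload keys mirror what __call__ reads: "inputs" (or "prompt") plus optional "parameters".

import requests

ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

payload = {
    "inputs": "a watercolor fox in a snowy forest",
    "parameters": {
        "num_inference_steps": 28,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 3.5,
        "seed": 42,
    },
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Accept": "image/png"},
    json=payload,
)
response.raise_for_status()
with open("output.png", "wb") as f:  # response body is the generated image bytes
    f.write(response.content)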
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ --extra-index-url https://download.pytorch.org/whl/cu126
+ torch==2.6.0+cu126
+ torchvision
+ torchaudio
+ huggingface_hub
+ torchao==0.9.0
+ diffusers==0.32.2
+ peft
+ transformers
+ numpy
+ scipy
+ Pillow
+ sentencepiece
+ protobuf
+ triton