| # import torch_tensorrt | |
| import os | |
| from typing import TypeAlias | |
| import torch | |
| from PIL.Image import Image | |
| from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny | |
| from huggingface_hub.constants import HF_HUB_CACHE | |
| from pipelines.models import TextToImageRequest | |
| from torch import Generator | |
| from torchao.quantization import quantize_, int8_weight_only | |
| from transformers import T5EncoderModel, CLIPTextModel, logging | |
| from functools import partial | |
# torch.compile preset: "reduce-overhead" mode (CUDA graphs) with
# fullgraph=True, so a graph break raises instead of silently splitting.
my_overhead_compile = partial(torch.compile, mode="reduce-overhead", fullgraph=True)

# Public alias for the pipeline type that load_pipeline() returns.
Pipeline: TypeAlias = FluxPipeline

# Let cuDNN benchmark candidate algorithms and cache the fastest per input shape.
torch.backends.cudnn.benchmark = True

# TorchInductor code-generation knobs, applied to every graph compiled in this
# process (see torch._inductor.config for their exact semantics).
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Read by the CUDA caching allocator when it initializes, so it has to be set
# before the first CUDA allocation happens.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"

# Hugging Face repo/revision of the int8 weight-only FLUX.1-schnell checkpoint.
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

# Hugging Face repo/revision of the tiny autoencoder used in place of the full VAE.
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
def load_pipeline() -> Pipeline:
    """Load, optimize and warm up the FLUX.1-schnell pipeline from the local HF cache.

    The transformer is loaded from the cached snapshot of ``CHECKPOINT`` at
    ``REVISION``. The VAE is replaced by ``TinyVAE`` at ``TinyVAE_REV``. All
    loads use ``local_files_only=True``, so the weights must already be cached.

    Returns:
        A ``FluxPipeline`` moved to CUDA that has already run two warm-up
        generations.
    """
    # Build the snapshot directory from the constants instead of repeating the
    # repo id / revision hash, so bumping REVISION can't leave a stale path.
    repo_folder = "models--" + CHECKPOINT.replace("/", "--")
    transformer_path = os.path.join(
        HF_HUB_CACHE, repo_folder, "snapshots", REVISION, "transformer"
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        # presumably the quantized weights are not stored as safetensors — TODO confirm
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    vae = AutoencoderTiny.from_pretrained(
        TinyVAE,
        revision=TinyVAE_REV,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        vae=vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.transformer.to(memory_format=torch.channels_last)
    pipeline.vae.to(memory_format=torch.channels_last)
    # NOTE(review): quantize_ targets nn.Linear layers by default; confirm the
    # tiny VAE actually has layers this affects.
    quantize_(pipeline.vae, int8_weight_only())
    # torch.compile(module) only compiles forward(). The pipeline decodes
    # latents through vae.decode(), which a compiled module wrapper forwards to
    # the uncompiled original. Compile the bound decode method instead.
    pipeline.vae.decode = my_overhead_compile(pipeline.vae.decode)
    pipeline.to("cuda")
    # Warm up with the same generation settings infer() uses, so the compiled
    # graphs and cached kernels match the real request path.
    for _ in range(2):
        pipeline(
            "cat",
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    """Generate a single image for ``request`` with a 4-step, guidance-free run.

    Args:
        request: Supplies the prompt and the target height/width.
        pipeline: The pipeline produced by ``load_pipeline``.
        generator: RNG passed through to the pipeline for sampling.

    Returns:
        The first image from the pipeline output.
    """
    result = pipeline(
        request.prompt,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        width=request.width,
        height=request.height,
        generator=generator,
    )
    first_image = result.images[0]
    return first_image
if __name__ == "__main__":
    # Ad-hoc benchmark: time pipeline loading, then four back-to-back requests.
    from time import perf_counter

    PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
    request = TextToImageRequest(prompt=PROMPT,
                                 height=None,
                                 width=None,
                                 seed=666)

    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")

    for _ in range(4):
        # infer() requires a generator. Seed it from the request so every
        # timed run is reproducible, and build it outside the timed region.
        generator = Generator(device=pipe_.device).manual_seed(request.seed)
        start_time = perf_counter()
        infer(request, pipe_, generator)
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")