| import os |
| from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel |
| from diffusers.image_processor import VaeImageProcessor |
| from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config |
| import torch |
| import gc |
| from PIL.Image import Image |
| from pipelines.models import TextToImageRequest |
| from torch import Generator |
| from torchao.quantization import quantize_, int8_weight_only |
| from time import perf_counter |
|
|
|
|
HOME = os.environ["HOME"]
# Pipeline components (by attribute name on the FluxPipeline) that quantize()
# should apply the weight-only quantization config to.
QUANTIZED_MODEL = ["vae"]
# NOTE(review): set after `import torch`; the CUDA caching allocator reads this
# env var lazily at first CUDA allocation, so it should still take effect —
# confirm for the installed torch version.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
# Hub repo id of the pre-quantized (int8 weight-only) Flux.1-schnell checkpoint.
FLUX_CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
# Pinned local snapshot path for the same checkpoint (used to load the
# transformer directly from disk without a hub lookup).
FLUX_CACHE = os.path.join(HOME, ".cache/huggingface/hub/models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b")
torch.backends.cudnn.benchmark = True          # autotune cudnn kernels for fixed shapes
torch.backends.cuda.matmul.allow_tf32 = True   # allow TF32 matmuls for speed
torch.cuda.set_per_process_memory_fraction(0.99)


# Quantization config applied to the components listed in QUANTIZED_MODEL.
QUANT_CONFIG = int8_weight_only()
DTYPE = torch.bfloat16
# Denoising steps for Flux.1-schnell (the distilled 4-step model).
NUM_STEPS = 4
# Fixed prompt used for warm-up and for the __main__ benchmark run.
PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
|
|
|
|
def empty_cache():
    """Release cached GPU memory and reset CUDA peak-memory statistics.

    Runs the Python garbage collector first so that unreferenced tensors are
    actually dropped before the CUDA caching allocator is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # torch.cuda.reset_max_memory_allocated() is a deprecated alias of
    # reset_peak_memory_stats(); the original called both, which did the same
    # reset twice.  One call to the non-deprecated API covers it.
    torch.cuda.reset_peak_memory_stats()
|
|
|
|
def quantize(pipe, config):
    """Apply `config` (a torchao quantization config) to every pipeline
    component named in the module-level QUANTIZED_MODEL list.

    The transformer is special-cased: torchao is asked to quantize it
    directly on the CUDA device.  Returns the (mutated) pipeline.
    """
    # Components quantized in place with the default device handling,
    # in the same order as before.
    for component in ("text_encoder", "text_encoder_2"):
        if component in QUANTIZED_MODEL:
            quantize_(getattr(pipe, component), config)
    if "transformer" in QUANTIZED_MODEL:
        quantize_(pipe.transformer, config, device="cuda")
    if "vae" in QUANTIZED_MODEL:
        quantize_(pipe.vae, config)
    return pipe
|
|
|
|
def load_pipeline() -> FluxPipeline:
    """Build, quantize, compile and warm up the Flux pipeline.

    Loads the transformer from the pinned local snapshot, assembles the full
    pipeline from the quantized checkpoint, enables VAE tiling/slicing,
    applies the configured quantization, compiles the VAE, then runs four
    fixed warm-up generations so later requests hit compiled/cached kernels.
    """
    empty_cache()
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(FLUX_CACHE, "transformer"),
        use_safetensors=False,
        torch_dtype=DTYPE,
    )
    pipe = FluxPipeline.from_pretrained(
        FLUX_CHECKPOINT,
        transformer=transformer,
        torch_dtype=DTYPE,
    )
    # Keep VAE memory bounded during decode.
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    quantize(pipe, QUANT_CONFIG)
    pipe.vae = torch.compile(pipe.vae)
    # Warm-up: four full generations at the production resolution.
    for _ in range(4):
        warmup_request = TextToImageRequest(prompt=PROMPT, height=1024, width=1024, seed=666)
        infer(warmup_request, pipe)
    return pipe
|
|
|
|
def encode_prompt(_pipeline, prompt: str):
    """Embed `prompt` using only the text-encoder half of the pipeline.

    Builds a weight-sharing view of `_pipeline` with the transformer and VAE
    dropped, moves it to CUDA, and runs the pipeline's own encode_prompt()
    under inference mode.  Returns whatever FluxPipeline.encode_prompt
    returns (prompt embeds, pooled embeds, text ids).
    """
    text_pipeline = FluxPipeline.from_pipe(
        _pipeline,
        transformer=None,
        vae=None,
    ).to("cuda")
    with torch.inference_mode():
        return text_pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )
|
|
|
|
def infer_latents(_pipeline, prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
    """Run the denoising loop and return packed latents (output_type="latent").

    Uses a weight-sharing view of `_pipeline` stripped of its text encoders,
    tokenizers and VAE, moved to CUDA.  When `seed` is None the sampler is
    left unseeded; otherwise a fresh torch Generator is seeded with it.
    """
    denoiser = FluxPipeline.from_pipe(
        _pipeline,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
    ).to("cuda")

    if seed is None:
        generator = None
    else:
        generator = Generator(denoiser.device).manual_seed(seed)
    # Fix: use the module-level NUM_STEPS constant (== 4) instead of a
    # duplicated magic number, so the step count has a single source of truth.
    outputs = denoiser(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=NUM_STEPS,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images

    return outputs
|
|
|
|
def decode_latents(vae, latents, width, height):
    """Decode packed Flux latents into a single PIL image.

    Moves the VAE to CUDA, unpacks the latent grid to image resolution,
    undoes the VAE's scaling/shift normalization, decodes under inference
    mode, and post-processes to PIL.  Falsy width/height fall back to the
    default resolution 64 * vae_scale_factor.
    """
    vae.to("cuda")
    # NOTE(review): 2 ** len(block_out_channels) matches the older diffusers
    # FluxPipeline convention; newer releases use len(...) - 1 with
    # compensating factors — confirm against the installed diffusers version.
    scale = 2 ** len(vae.config.block_out_channels)
    if not width:
        width = 64 * scale
    if not height:
        height = 64 * scale
    processor = VaeImageProcessor(vae_scale_factor=scale)

    unpacked = FluxPipeline._unpack_latents(latents, height, width, scale)
    # Invert the latent normalization applied at encode time.
    unpacked = unpacked / vae.config.scaling_factor + vae.config.shift_factor
    with torch.inference_mode():
        decoded = vae.decode(unpacked, return_dict=False)[0]
    return processor.postprocess(decoded, output_type="pil")[0]
|
|
|
|
| |
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    """Serve one text-to-image request, staging pipeline parts between
    CPU and GPU to keep peak VRAM low.

    Encodes the prompt, offloads the text encoder, denoises to latents,
    offloads the transformer blocks, decodes with the VAE, then offloads
    the VAE before returning the PIL image.
    """
    embeds, pooled, text_ids = encode_prompt(_pipeline, request.prompt)

    _pipeline.text_encoder.to("cpu")
    latents = infer_latents(_pipeline, embeds, pooled, request.width, request.height, request.seed)
    # Embeddings are no longer needed once the latents exist.
    del embeds, pooled, text_ids
    # Offload the transformer's blocks to free VRAM for VAE decoding.
    _pipeline.transformer.single_transformer_blocks.to("cpu")
    _pipeline.transformer.transformer_blocks.to("cpu")
    image = decode_latents(_pipeline.vae, latents, request.width, request.height)
    torch.cuda.reset_peak_memory_stats()
    _pipeline.vae.to("cpu")
    return image
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # Benchmark entry point: load (and warm up) the pipeline once, then time
    # four end-to-end generations of the same fixed request.
    # NOTE(review): height/width are None here, so decode_latents falls back
    # to the model-default resolution — confirm that is intended.
    request = TextToImageRequest(prompt=PROMPT,
                                 height=None,
                                 width=None,
                                 seed=666)
    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")
    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_)
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")
|
|
| |
|
|