import os
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config
import torch
import gc
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from time import perf_counter

HOME = os.environ["HOME"]

# Pipeline components to quantize at load time; the transformer checkpoint is
# loaded pre-quantized (int8 weight-only), so it is not listed here.
QUANTIZED_MODEL = ["text_encoder_2", "vae"]

# Allocator tuning; set before the first CUDA allocation so the allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"

FLUX_CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
FLUX_CACHE = os.path.join(
    HOME,
    ".cache/huggingface/hub/models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b",
)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.99)

QUANT_CONFIG = int8_weight_only()
DTYPE = torch.bfloat16
NUM_STEPS = 4
PROMPT = "martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle"


def empty_cache():
    """Release cached CUDA memory and reset the memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def quantize(pipe, config):
    """Apply the given torchao config to the components listed in QUANTIZED_MODEL."""
    if "text_encoder" in QUANTIZED_MODEL:
        quantize_(pipe.text_encoder, config)
    if "text_encoder_2" in QUANTIZED_MODEL:
        quantize_(pipe.text_encoder_2, config)
    if "transformer" in QUANTIZED_MODEL:
        quantize_(pipe.transformer, config, device="cuda")
    if "vae" in QUANTIZED_MODEL:
        quantize_(pipe.vae, config)
    return pipe


def load_pipeline() -> FluxPipeline:
    empty_cache()
    # The transformer snapshot is stored as a non-safetensors, pre-quantized checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(FLUX_CACHE, "transformer"),
        use_safetensors=False,
        torch_dtype=DTYPE,
    )
    pipe = FluxPipeline.from_pretrained(FLUX_CHECKPOINT, transformer=transformer, torch_dtype=DTYPE)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    quantize(pipe, QUANT_CONFIG)
    pipe.to("cuda")
    # Warm-up pass so the first timed request does not pay one-time initialization costs.
    request = TextToImageRequest(prompt=PROMPT, height=1024, width=1024, seed=666)
    infer(request, pipe)
    # pipe.enable_model_cpu_offload()
    return pipe


def encode_prompt(_pipeline, prompt: str):
    """Run only the text-encoder stage and return the prompt embeddings."""
    pipeline = FluxPipeline.from_pipe(
        _pipeline,
        transformer=None,
        vae=None,
    ).to("cuda")
    with torch.no_grad():
        outputs = pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )
    del pipeline
    empty_cache()
    return outputs


def infer_latents(_pipeline, prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
    """Run only the denoising (transformer) stage and return packed latents."""
    pipeline = FluxPipeline.from_pipe(
        _pipeline,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
    ).to("cuda")
    if seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(seed)
    outputs = pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=4,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images
    del pipeline
    empty_cache()
    return outputs


def decode_latents(vae, latents, width, height):
    """Unpack and decode Flux latents into a PIL image using the VAE alone."""
    vae.to("cuda")
    # Scale factor between latent and pixel resolution, matching what
    # FluxPipeline._unpack_latents expects in the diffusers version used here.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
    # Default to 64 latent patches per side (1024 x 1024 pixels) when no size is given.
    width = width or 64 * vae_scale_factor
    height = height or 64 * vae_scale_factor
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
    return image_processor.postprocess(image, output_type="pil")[0]


# Split-stage inference (encode -> denoise -> decode), kept for reference:
# def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
#     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(_pipeline, request.prompt)
#     latents = infer_latents(_pipeline, prompt_embeds, pooled_prompt_embeds, request.width, request.height, request.seed)
#     del prompt_embeds
#     del pooled_prompt_embeds
#     del text_ids
#     # _pipeline.transformer.to("cpu")
#     image = decode_latents(_pipeline.vae, latents, request.width, request.height)
#     torch.cuda.reset_peak_memory_stats()
#     return image


def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    if request.seed is None:
        generator = None
    else:
        generator = Generator(device="cuda").manual_seed(request.seed)
    torch.cuda.reset_peak_memory_stats()
    image = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    ).images[0]
    return image


if __name__ == "__main__":
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)

    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")

    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_)
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")

    # pipe("cat holding a womboai sign", num_inference_steps=4, guidance_scale=0, generator=torch.Generator(pipe.device).manual_seed(666)).images[0].save("sample.png")