import gc
import os
import time

import torch
import torch._dynamo
from diffusers import AutoencoderKL, AutoencoderTiny, DiffusionPipeline, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from PIL import Image as img
from PIL.Image import Image
from torch import Generator
from torchao.quantization import int8_weight_only, quantize_
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from pipelines.models import TextToImageRequest

# Alias used in the type annotations of this module. As written it is ``None``,
# so ``-> Pipeline`` annotations do not name a real type.
Pipeline = None

MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# File that trace_and_save_vae_decoder() writes the TorchScript VAE decoder to.
traced_vae_decode_path = "traced_vae_decode.pt"

# Prompt used only for the warm-up generations in load_pipeline().
_WARMUP_PROMPT = "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness"


def empty_cache():
    """Free unused cached CUDA memory and reset CUDA peak-memory statistics.

    Runs the Python garbage collector first, then asks the CUDA caching
    allocator to release its unused cached blocks and resets all peak-memory
    counters. Prints how long the flush took.
    """
    start = time.perf_counter()
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets every peak counter. The deprecated
    # reset_max_memory_allocated() alias only forwards here (with a
    # FutureWarning), so calling both is redundant.
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")


def load_pipeline() -> Pipeline:
    """Load FLUX.1-schnell with an int8 weight-only quantized VAE and warm it up.

    The VAE is loaded separately in bfloat16 and quantized in place with
    torchao's int8 weight-only scheme. It is then injected into the diffusion
    pipeline, which is loaded in the same dtype. Diffusers' sequential CPU
    offload is enabled, and two 1024x1024 warm-up generations are run
    (presumably to absorb one-time setup cost before real requests).

    Returns:
        The loaded diffusers pipeline.
    """
    empty_cache()
    dtype = torch.bfloat16

    vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=dtype)
    # quantize_ modifies the module in place; its return value is not needed.
    quantize_(vae, int8_weight_only())

    pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype)
    pipeline.enable_sequential_cpu_offload()

    for _ in range(2):
        empty_cache()
        pipeline(
            prompt=_WARMUP_PROMPT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


def trace_and_save_vae_decoder(vae, latents):
    """Try to TorchScript-trace ``vae.decode`` and save it to ``traced_vae_decode_path``.

    Args:
        vae: VAE whose ``decode`` method is traced.
        latents: Example latents used as the tracing input.

    Returns:
        The traced decoder on success. On any failure, ``vae.decode`` itself,
        so the caller always gets a callable.
    """
    try:
        # NOTE(review): torch.jit.trace only accepts tensors (or nested
        # tuples/lists/dicts of tensors) as example inputs/outputs. The bool
        # passed for ``return_dict`` here (and the non-tuple output decode
        # produces when it is True) likely make tracing raise every time, so
        # this probably always takes the fallback path - confirm before relying
        # on the traced decoder.
        traced_vae_decode = torch.jit.trace(vae.decode, (latents, True))
        torch.jit.save(traced_vae_decode, traced_vae_decode_path)
        return traced_vae_decode
    except Exception as e:
        # Tracing is a best-effort optimisation: report and fall back.
        print(f"JIT tracing failed: {e}")
        return vae.decode  # Fall back to untraced decoder.
def decode_latents_to_image(latents, height: int, width: int, vae):
    """Decode packed FLUX latents for one image into a PIL image.

    Args:
        latents: Packed latents for a single image, e.g. one element of
            ``pipeline(..., output_type="latent").images``. A batch axis is
            added here before unpacking.
        height: Target height in pixels. Falsy values (None/0) mean 1024.
        width: Target width in pixels. Falsy values (None/0) mean 1024.
        vae: The pipeline's VAE. Its config supplies the scale/shift factors,
            and its ``decode`` method produces the image.

    Returns:
        The decoded image as a ``PIL.Image.Image``.
    """
    height = height or 1024
    width = width or 1024

    # 2 ** number-of-VAE-blocks; a missing/empty block_out_channels gives 1.
    # NOTE(review): presumably this matches the factor the installed diffusers
    # FluxPipeline passes to _unpack_latents - verify when upgrading diffusers.
    vae_scale_factor = 2 ** len(vae.config.block_out_channels or ())
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    # The TorchScript path (trace_and_save_vae_decoder) is currently disabled;
    # the VAE decodes eagerly.
    with torch.no_grad():
        unpacked = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
        # Invert the latent normalisation (divide by scaling_factor, then add
        # shift_factor) before decoding.
        unpacked = (unpacked / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(unpacked, return_dict=False)[0]
        return image_processor.postprocess(image, output_type="pil")[0]


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for *request* using the loaded FLUX pipeline.

    The pipeline is run with ``output_type="latent"`` (4 steps, no guidance).
    The resulting latents are decoded separately by decode_latents_to_image.
    The request's seed drives a CUDA generator, so results are reproducible
    for a given seed.
    """
    empty_cache()
    generator = Generator("cuda").manual_seed(request.seed)
    latents = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="latent",
    ).images[0]
    return decode_latents_to_image(latents, request.height, request.width, pipeline.vae)