"""FLUX.1-schnell text-to-image pipeline: loading, warm-up and inference."""

import gc
import os
from time import perf_counter

import torch
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torch import Generator
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
    CLIPTokenizer,
    CLIPTextModel,
    CLIPTextConfig,
    T5Config,
)

from pipelines.models import TextToImageRequest

HOME = os.environ["HOME"]

# Allocator settings for PyTorch's CUDA caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
# NOTE(review): CPython documents PYTHONMALLOC as being read at interpreter
# startup, so setting it here presumably only affects child processes that
# inherit this environment -- confirm this is the intent.
os.environ['PYTHONMALLOC'] = 'malloc'

CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
DTYPE = torch.bfloat16
NUM_STEPS = 4
# Shared by the warm-up call and by `infer` so both use the same prompt length.
MAX_SEQUENCE_LENGTH = 256


def empty_cache() -> None:
    """Run the Python garbage collector and clear CUDA cache and peak stats.

    Calls `gc.collect()`, then `torch.cuda.empty_cache()`, then
    `torch.cuda.reset_peak_memory_stats()`.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # The deprecated `torch.cuda.reset_max_memory_allocated()` is not called:
    # PyTorch documents it as forwarding to `reset_peak_memory_stats()`.
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> FluxPipeline:
    """Load the FLUX pipeline from `CHECKPOINT`, configure it and warm it up.

    Returns:
        The configured `FluxPipeline`, after one warm-up generation has run.
    """
    empty_cache()
    pipe = FluxPipeline.from_pretrained(CHECKPOINT, torch_dtype=DTYPE)

    # Request the channels_last memory format on the three large sub-models.
    for module in (pipe.text_encoder_2, pipe.transformer, pipe.vae):
        module.to(memory_format=torch.channels_last)

    pipe.vae = torch.compile(pipe.vae)
    # NOTE(review): presumably keeps the compiled VAE out of sequential CPU
    # offloading -- `_exclude_from_cpu_offload` is a private diffusers
    # attribute; verify against the installed diffusers version.
    pipe._exclude_from_cpu_offload = ["vae"]
    pipe.enable_sequential_cpu_offload()

    empty_cache()
    # Warm-up run with the same step count and sequence length as `infer`;
    # the output is discarded.
    pipe(
        "dog",
        guidance_scale=0.0,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        num_inference_steps=NUM_STEPS,
    )
    return pipe


@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image.Image:
    """Generate one image for `request` using `_pipeline`.

    Args:
        request: Supplies `prompt`, `width`, `height` and `seed`. When `seed`
            is None no generator is passed to the pipeline.
        _pipeline: The pipeline to call, e.g. the result of `load_pipeline`.

    Returns:
        The first entry of the pipeline output's `images`, requested with
        `output_type="pil"`.
    """
    generator = None
    if request.seed is not None:
        generator = Generator(device="cuda").manual_seed(request.seed)

    # `empty_cache` also resets the CUDA peak-memory statistics, so no
    # separate reset is needed before it.
    empty_cache()

    output = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        num_inference_steps=NUM_STEPS,
    )
    return output.images[0]