"""FLUX.1-schnell text-to-image pipeline: one-time loading/warm-up and per-request inference."""
import gc
import os
from time import perf_counter

import torch
from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torch import Generator
from transformers import (
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
    T5Config,
    T5EncoderModel,
    T5TokenizerFast,
)

from pipelines.models import TextToImageRequest

# The CUDA caching allocator reads this when it is first initialized (first
# CUDA use), so it must be set before any GPU work below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class EightQuantize:
    """Fake-quantize a tensor to ``bits``-bit precision and dequantize it back.

    Values are mapped onto ``qmax + 1`` evenly spaced levels spanning the
    tensor's own ``[min, max]`` range (asymmetric min-max quantization).
    The result has the input's shape and dtype but takes at most
    ``2 ** bits`` distinct values. Negative values are preserved; for
    tensors whose minimum is 0 this matches plain ``x.max() / qmax`` scaling.
    """

    def __init__(self, bits=8):
        self.bits = bits
        self.qmax = (1 << bits) - 1  # highest integer level, e.g. 255 for 8 bits

    def __call__(self, x):
        """Return a quantize-dequantized copy of tensor ``x``."""
        if x.numel() == 0:
            # min()/max() raise on empty tensors; nothing to quantize anyway.
            return x.clone()
        x_min = x.min()
        scale = (x.max() - x_min) / self.qmax
        if scale == 0:
            # Constant tensor: already exactly representable, and dividing by
            # a zero scale would produce NaNs.
            return x.clone()
        x_quant = torch.clip(torch.round((x - x_min) / scale), 0, self.qmax)
        return x_quant * scale + x_min


CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
DTYPE = torch.bfloat16
NUM_STEPS = 4
# T5 prompt token limit passed to every pipeline call (warm-up and requests);
# the diffusers Flux docs note schnell does not support values above 256.
MAX_SEQUENCE_LENGTH = 256


def empty_cache():
    """Run Python GC, release cached CUDA memory and reset CUDA peak-memory stats.

    The CUDA calls are skipped when no CUDA device is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Also covers the deprecated reset_max_memory_allocated(), which is
        # now just an alias that resets all peak statistics.
        torch.cuda.reset_peak_memory_stats()


def _quantize_vae_weights(pipe: FluxPipeline, bits: int = 8) -> None:
    """Fake-quantize every VAE parameter of ``pipe`` in place (best effort)."""
    try:
        quantizer = EightQuantize(bits)
        with torch.no_grad():
            for param in pipe.vae.parameters():
                param.data = quantizer(param.data)
    except Exception as e:
        # Deliberately non-fatal: a quantization failure should not prevent
        # the pipeline from loading (parameters already processed stay
        # quantized).
        print(f"Quantization warning: {e}")


def load_pipeline(*, quantize_vae: bool = False) -> FluxPipeline:
    """Load, configure and warm up the FLUX.1-schnell pipeline.

    Args:
        quantize_vae: When True, fake-quantize the VAE weights to 8 bits
            before compiling it. Off by default.

    Returns:
        The ready-to-use ``FluxPipeline``. It has already run one warm-up
        generation, so the compiled VAE has been traced before the first
        real request.
    """
    empty_cache()
    pipe = FluxPipeline.from_pretrained(CHECKPOINT, torch_dtype=DTYPE)

    # memory_format only affects 4-D parameters/buffers (e.g. conv weights);
    # modules made of linear layers are largely unchanged by this.
    pipe.text_encoder_2.to(memory_format=torch.channels_last)
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if quantize_vae:
        _quantize_vae_weights(pipe)

    pipe.vae = torch.compile(pipe.vae)
    # Exclude the compiled VAE from sequential offloading so it stays whole
    # on the accelerator; every other component is offloaded to CPU.
    pipe._exclude_from_cpu_offload = ["vae"]
    pipe.enable_sequential_cpu_offload()
    empty_cache()

    # Warm-up generation: torch.compile compiles lazily on first call, so run
    # once here instead of paying that cost on the first request.
    pipe(
        "transliterator, omnicredulous, finicality, scotia, anesthesia",
        guidance_scale=0.0,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        num_inference_steps=NUM_STEPS,
    )
    return pipe


@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image.Image:
    """Generate one image for ``request`` using a pipeline from ``load_pipeline``.

    A CUDA generator seeded with ``request.seed`` is used when a seed is
    given; otherwise generation is unseeded.

    Returns:
        The first generated image as a PIL image.
    """
    generator = (
        None
        if request.seed is None
        else Generator(device="cuda").manual_seed(request.seed)
    )
    empty_cache()  # also resets peak-memory stats for this request
    result = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        num_inference_steps=NUM_STEPS,
    )
    return result.images[0]