import os
from functools import wraps
from typing import Callable

import torch
import torch._dynamo
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from torch._inductor import config as ind_config

from pipelines.models import TextToImageRequest


def error_handler(func: Callable):
    """Log exceptions instead of propagating them; returns None on failure."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
    return wrapper


class TorchOptimizer:
    """Global torch performance knobs and CUDA memory housekeeping."""

    def optimize_settings(self):
        # Allow TF32 on matmuls and cuDNN kernels; trades a little precision
        # for throughput, which is acceptable for bf16 diffusion inference.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of this call.
        torch.cuda.reset_peak_memory_stats()


class PipelineManager:
    """Loads, optimizes, and caches a Flux text-to-image pipeline."""

    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()

        # Configure environment
        torch._dynamo.config.suppress_errors = True
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "True"

        # Initialize torch settings
        self.optimizer.optimize_settings()

    def load_transformer(self):
        # Load the transformer weights directly from the local HF hub cache
        # snapshot, pinned to the same revision as the pipeline checkpoint.
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/"
            "488528b6f815bff1bbc747cf1e0947c77c544665",
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )

    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse QKV projections into single matmuls in transformer and VAE
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # Optimize memory layout
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Configure torch inductor
        ind_config.disable_progress = False
        ind_config.conv_1x1_as_mm = True

        # Compile the hot paths: the transformer and the VAE decoder
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True,
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True,
        )
        return pipe

    def load_pipeline(self):
        # Load transformer model
        transformer_model = self.load_transformer()

        # Create pipeline
        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
        pipe.to("cuda")

        # optimize_pipeline() is wrapped by @error_handler and returns None
        # on failure, so fall back to the unoptimized pipeline in that case.
        pipe = self.optimize_pipeline(pipe) or pipe

        # Warm-up run to trigger torch.compile before the first real request
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]
        print("Finished torch compilation")

        return pipe

    def run_inference(self, request: TextToImageRequest) -> Image:
        # Lazily build the pipeline on the first request, then reuse it
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()

        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)

        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
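

# Example usage -- a minimal sketch, assuming TextToImageRequest is a simple
# model exposing the prompt, seed, height, and width fields that
# run_inference() reads above. Its actual constructor lives in
# pipelines.models and may differ; the field names here are inferred from
# the attribute accesses in run_inference().
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dawn",
        seed=42,
        height=1024,
        width=1024,
    )
    # The first call loads and compiles the pipeline (slow); subsequent
    # calls reuse the cached pipeline and compiled graphs.
    image = manager.run_inference(request)
    image.save("output.png")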