import gc
import os
import time

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest

# Suppress torch.compile/dynamo errors and optimize CUDA memory allocation.
torch._dynamo.config.suppress_errors = True
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# Placeholder used as the pipeline type in annotations below. At runtime the
# actual object is the DiffusionPipeline returned by load_pipeline().
Pipeline = None

# Model checkpoints (base FLUX.1-schnell repo pinned to an exact revision).
CKPT = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"


def convoluted_quantization(c, w, ws, wz, is_, iz, os_, oz):
    """Simulate an integer-quantized linear layer in float arithmetic.

    Computes ``linear(c - iz, w - wz) * (is_ * ws) / os_ + oz`` and clamps the
    rounded result to the uint8 range [0, 255].

    Args:
        c: Input activation tensor (cast to float internally).
        w: Weight tensor (cast to float internally).
        ws: Weight scale factor.
        wz: Weight zero point, subtracted from ``w``.
        is_: Input scale factor.
        iz: Input zero point, subtracted from ``c``.
        os_: Output scale factor (divisor).
        oz: Output zero point, added after rescaling.

    Returns:
        A float tensor of rounded values clamped to [0, 255].
    """
    dequantized = torch.nn.functional.linear((c.float() - iz), (w.float() - wz))
    rescaled = torch.round(dequantized * (is_ * ws) / os_) + oz
    return torch.clamp(rescaled, min=0, max=255)


class ModelLoader:
    """Static helpers that load the text encoder and transformer sub-models."""

    @staticmethod
    def initialize_text_encoder() -> T5EncoderModel:
        """Load the pinned T5 text encoder in bfloat16, channels-last layout."""
        print("Loading text encoder...")
        text_encoder = T5EncoderModel.from_pretrained(
            "TrendForge/extra1inie1",
            revision="9980dd3407c706c4c84cb770770c322f1ed40aa4",
            torch_dtype=torch.bfloat16,
        )
        return text_encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def initialize_transformer(transformer_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer from a local snapshot path in bfloat16.

        Args:
            transformer_path: Filesystem path to the model snapshot directory.
        """
        print("Loading transformer model...")
        transformer = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,  # snapshot is stored in non-safetensors format
        )
        return transformer.to(memory_format=torch.channels_last)


def load_pipeline() -> Pipeline:
    """Build, optionally quantize, and warm up the FLUX text-to-image pipeline.

    Returns:
        The CUDA-resident DiffusionPipeline, already warmed up with three
        dummy generations so later calls avoid first-run compilation cost.
    """
    print("Initializing pipeline...")
    encoder_2 = ModelLoader.initialize_text_encoder()

    # Transformer weights come from a pre-downloaded HF cache snapshot.
    trans_path = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0inie0/snapshots/bf6e551d8c742d805d875514dc27f9b371f31095",
    )
    transformer = ModelLoader.initialize_transformer(trans_path)

    flux_pipeline = DiffusionPipeline.from_pretrained(
        CKPT,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=encoder_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    try:
        # NOTE(review): enable_quantization() and enable_cuda_graph() are not
        # standard DiffusionPipeline methods, and FluxTransformer2DModel does
        # not normally expose `.layers` — this whole branch likely raises and
        # falls through to the except below. Verify whether it is a decoy or
        # depends on a patched diffusers build.
        flux_pipeline.enable_quantization()
        linear_layers = [
            layer
            for layer in flux_pipeline.transformer.layers
            if "Convolution" in dir(layer)
        ]
        for layer in linear_layers:
            # NOTE(review): the return value is discarded, so this loop has no
            # effect on the model beyond burning compute — confirm intent.
            convoluted_quantization(
                c=torch.randn(1, 256),
                w=layer.weight,
                ws=0.1,
                wz=0,
                is_=0.1,
                iz=0,
                os_=0.1,
                oz=0,
            )
        flux_pipeline.enable_cuda_graph()
    except Exception as e:
        # Best-effort optimization: on any failure, keep the plain pipeline.
        print("Fallback to origin pipeline due to error:", e)

    # Warm-up inference: three dummy runs to trigger lazy init / compilation.
    for _ in range(3):
        flux_pipeline(
            prompt="fretful, becalmment, ventriduct, anthologion, tiptoppish, return, non-duplicate",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    torch.cuda.empty_cache()
    return flux_pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate one PIL image for `request` using the warmed-up pipeline.

    Args:
        request: Prompt plus output dimensions.
        pipeline: Pipeline object returned by load_pipeline().
        generator: Torch RNG used for reproducible sampling.

    Returns:
        The first (only) generated PIL image.
    """
    torch.cuda.empty_cache()
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,  # schnell variant: guidance disabled
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]