import os

# These variables are read by the libraries themselves (the PyTorch CUDA
# caching allocator and Hugging Face tokenizers). Set them before torch and
# the libraries built on it are imported, so they take effect even if one of
# those imports initialises CUDA.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

import gc
import json
import math
from typing import Any, Dict

import torch
from torch import Generator
import torch._dynamo
import transformers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from huggingface_hub.constants import HF_HUB_CACHE
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from PIL.Image import Image

# -----------------------------------------------------------------------------
# Environment Configuration & Global Constants
# -----------------------------------------------------------------------------

# Let torch.compile / dynamo fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

# Base diffusion checkpoint, pinned to a specific revision.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# T5 encoder that replaces the base checkpoint's `text_encoder_2`.
TEXT_ENCODER_ID = "city96/t5-v1_1-xxl-encoder-bf16"
TEXT_ENCODER_REV = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"

# Transformer snapshot directory, relative to HF_HUB_CACHE. It is passed to
# `from_pretrained` as a filesystem path, so it is expected to already exist
# in the local cache.
TRANSFORMER_SNAPSHOT = (
    "models--park234--FLUX1-SCHENELL-INT8/snapshots/"
    "59c2f006f045d9ccdc2e3ab02150b8df0adfafc6"
)

# Sampling settings shared by the warm-up runs and `inference`.
NUM_INFERENCE_STEPS = 4
MAX_SEQUENCE_LENGTH = 256
GUIDANCE_SCALE = 0.0

# Warm-up configuration used by `load_pipeline`.
WARMUP_RUNS = 3
WARMUP_PROMPT = "unrectangular, uneucharistical, pouchful, uplay, person"
WARMUP_SIZE = 1024


# -----------------------------------------------------------------------------
# Quantization and Linear Transformation Utilities
# -----------------------------------------------------------------------------


def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
    *,
    qmin: int = 0,
    qmax: int = 255,
) -> torch.Tensor:
    """Run an affine-quantized linear layer via float dequantize/requantize.

    Both operands are shifted by their zero-points and multiplied in float32.
    The three scales are folded into a single requantization factor
    ``in_scale * w_scale / out_scale``. The result is then shifted by
    ``out_zero``, rounded, and clamped to ``[qmin, qmax]``.

    Args:
        input_tensor: Quantized activations; last dimension must equal the
            weight's ``in_features`` (as required by ``F.linear``). Moved to
            ``weight_tensor``'s device if it lives elsewhere.
        weight_tensor: Quantized weights, shaped ``(out_features, in_features)``
            as ``F.linear`` expects.
        w_scale: Scale factor for the weights.
        w_zero: Zero-point for the weights.
        in_scale: Scale factor for the input.
        in_zero: Zero-point for the input.
        out_scale: Scale factor for the output (must be non-zero).
        out_zero: Zero-point for the output.
        qmin: Lower clamp bound of the output grid (default 0).
        qmax: Upper clamp bound of the output grid (default 255). The defaults
            give the unsigned 8-bit range; pass ``qmin=-128, qmax=127`` for
            signed int8.

    Returns:
        A float32 tensor of shape ``(..., out_features)`` on
        ``weight_tensor``'s device. Its values are integers in ``[qmin, qmax]``.

    Raises:
        ValueError: If ``qmin > qmax``.
    """
    if qmin > qmax:
        raise ValueError(f"qmin ({qmin}) must not exceed qmax ({qmax})")

    # Dequantize (remove zero-points) in float32. Aligning the activation's
    # device with the weight's avoids a device-mismatch error when a CPU input
    # is paired with GPU weights.
    inp_deq = input_tensor.to(device=weight_tensor.device, dtype=torch.float32) - in_zero
    wt_deq = weight_tensor.float() - w_zero

    # Integer-domain accumulation carried out in float32.
    lin_result = torch.nn.functional.linear(inp_deq, wt_deq)

    # Requantize into the output grid.
    requantized = lin_result * ((in_scale * w_scale) / out_scale) + out_zero
    return torch.clamp(torch.round(requantized), qmin, qmax)


# -----------------------------------------------------------------------------
# Model Initialization Functions
# -----------------------------------------------------------------------------


def initialize_text_encoder() -> T5EncoderModel:
    """Load the bf16 T5 encoder from TEXT_ENCODER_ID at the pinned revision.

    ``memory_format=torch.channels_last`` is requested as in the rest of the
    module. Per the ``torch.nn.Module.to`` docs, it only applies to 4-D
    parameters and buffers.
    """
    print("Initializing T5 text encoder...")
    encoder = T5EncoderModel.from_pretrained(
        TEXT_ENCODER_ID,
        revision=TEXT_ENCODER_REV,
        torch_dtype=torch.bfloat16,
    )
    return encoder.to(memory_format=torch.channels_last)


def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """Load the Flux transformer in bf16 from the local directory ``transformer_dir``.

    ``use_safetensors=False`` is kept because the snapshot is loaded with it
    disabled. Presumably the snapshot ships non-safetensors weights; confirm
    against the snapshot contents.
    """
    print("Initializing Flux transformer...")
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    return transformer.to(memory_format=torch.channels_last)


# -----------------------------------------------------------------------------
# Pipeline Construction
# -----------------------------------------------------------------------------


def load_pipeline() -> DiffusionPipeline:
    """Build the diffusion pipeline on CUDA, enable VAE tiling, and warm it up.

    The transformer (from the local snapshot under HF_HUB_CACHE) and the T5
    encoder are loaded separately. Both are passed as component overrides to
    ``DiffusionPipeline.from_pretrained(MODEL_ID)``, which supplies the
    remaining components. VAE tiling is enabled on a best-effort basis: a
    failure is reported and does not abort loading. Finally the pipeline is
    run ``WARMUP_RUNS`` times on a fixed prompt at ``WARMUP_SIZE`` x
    ``WARMUP_SIZE``.

    Returns:
        DiffusionPipeline: The configured pipeline, already moved to ``"cuda"``.
    """
    transformer_dir = os.path.join(HF_HUB_CACHE, TRANSFORMER_SNAPSHOT)
    transformer_model = initialize_transformer(transformer_dir)
    encoder = initialize_text_encoder()

    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Tiling is enabled in its own try block so that no unrelated setup step
    # can silently skip it. The diffusers VAE API method is `enable_tiling`.
    try:
        pipeline_instance.vae.enable_tiling()
    except AttributeError as err:
        print("Warning: VAE tiling could not be enabled:", err)

    # Warm-up passes; their outputs are intentionally discarded.
    for _ in range(WARMUP_RUNS):
        pipeline_instance(
            prompt=WARMUP_PROMPT,
            width=WARMUP_SIZE,
            height=WARMUP_SIZE,
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
        )

    return pipeline_instance


# -----------------------------------------------------------------------------
# Inference Function
# -----------------------------------------------------------------------------


@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """Generate one image for ``request`` with ``pipeline``.

    Empties the CUDA cache, seeds a generator on the pipeline's device from
    ``request.seed``, and runs the pipeline. It uses the module's shared
    sampling settings and the request's prompt, height, and width.

    Parameters:
        request (TextToImageRequest): Supplies ``prompt``, ``height``,
            ``width`` and ``seed``.
        pipeline (DiffusionPipeline): The pipeline returned by ``load_pipeline``.

    Returns:
        Image: The first image of the pipeline's PIL output.
    """
    # Release cached-but-unused CUDA memory blocks before generating.
    torch.cuda.empty_cache()
    rnd_gen = Generator(pipeline.device).manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=rnd_gen,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]


# -----------------------------------------------------------------------------
# Example Main Flow (Optional)
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Construct the diffusion pipeline.
    diffusion_pipe = load_pipeline()

    # Create a sample request (assumes TextToImageRequest accepts these fields).
    sample_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        height=512,
        width=512,
        seed=1234,
    )

    # Generate an image; save or display `result_image` as desired.
    result_image = inference(sample_request, diffusion_pipe)