| import os |
| import gc |
| import json |
| import math |
| from typing import Any, Dict |
|
|
| import torch |
| from torch import Generator |
| import torch._dynamo |
|
|
| import transformers |
| from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel |
| from huggingface_hub.constants import HF_HUB_CACHE |
|
|
| from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny |
| from pipelines.models import TextToImageRequest |
|
|
| from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only |
|
|
| from PIL.Image import Image |
|
|
| |
| |
| |
# Let torch.compile/dynamo fall back to eager execution on compilation errors
# instead of crashing the whole pipeline.
torch._dynamo.config.suppress_errors = True
# Allow the CUDA caching allocator to grow segments dynamically, which reduces
# fragmentation-related OOMs with large models.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Enable parallelism in the HuggingFace fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "True"


# Base FLUX.1-schnell repo and the pinned revision used for all pipeline
# components that are not explicitly overridden in load_pipeline().
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"
|
|
|
|
| |
| |
| |
| def perform_linear_quant( |
| input_tensor: torch.Tensor, |
| weight_tensor: torch.Tensor, |
| w_scale: float, |
| w_zero: int, |
| in_scale: float, |
| in_zero: int, |
| out_scale: float, |
| out_zero: int, |
| ) -> torch.Tensor: |
| """ |
| Performs a quantization-aware linear operation on the input tensor. |
| |
| This function first dequantizes both the input and the weights, |
| applies a linear transformation, and then requantizes the result. |
| |
| Parameters: |
| input_tensor (torch.Tensor): The input tensor. |
| weight_tensor (torch.Tensor): The weight tensor. |
| w_scale (float): Scale factor for the weights. |
| w_zero (int): Zero-point for the weights. |
| in_scale (float): Scale factor for the input. |
| in_zero (int): Zero-point for the input. |
| out_scale (float): Scale factor for the output. |
| out_zero (int): Zero-point for the output. |
| |
| Returns: |
| torch.Tensor: The quantized output tensor. |
| """ |
| |
| inp_deq = input_tensor.float() - in_zero |
| wt_deq = weight_tensor.float() - w_zero |
|
|
| |
| lin_result = torch.nn.functional.linear(inp_deq, wt_deq) |
|
|
| |
| requantized = lin_result * ((in_scale * w_scale) / out_scale) + out_zero |
| return torch.clamp(torch.round(requantized), 0, 255) |
|
|
|
|
| |
| |
| |
def initialize_text_encoder() -> T5EncoderModel:
    """
    Load the bf16 T5-XXL text encoder pinned to a fixed revision.

    Returns:
        T5EncoderModel: The encoder converted to channels-last memory format
        for better CUDA kernel throughput downstream.
    """
    print("Initializing T5 text encoder...")
    model = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
|
|
|
|
def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Load the Flux transformer weights from a local directory.

    Parameters:
        transformer_dir (str): Directory containing the serialized model
            (non-safetensors format).

    Returns:
        FluxTransformer2DModel: The transformer in channels-last memory format.
    """
    print("Initializing Flux transformer...")
    model = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        use_safetensors=False,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
|
|
|
|
| |
| |
| |
def load_pipeline() -> DiffusionPipeline:
    """
    Construct the diffusion pipeline by combining the text encoder and transformer.

    Applies a dummy quantization pass over the transformer's Linear submodules,
    enables VAE tiling, and runs several warm-up generations to stabilize
    performance (compilation, allocator, autotuning).

    Returns:
        DiffusionPipeline: The configured diffusion pipeline, resident on CUDA.
    """
    # Pre-quantized transformer snapshot stored in the local HF hub cache.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)
    encoder = initialize_text_encoder()

    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    try:
        # FIX: the original filtered `transformer.layers` on a non-existent
        # `__classname__` attribute, which raised AttributeError on the first
        # module and silently skipped this entire section. Walk all submodules
        # and match nn.Linear by type instead (FluxTransformer2DModel exposes
        # its blocks via named submodules, not a `.layers` list).
        linear_modules = [
            mod for mod in pipeline_instance.transformer.modules()
            if isinstance(mod, torch.nn.Linear)
        ]
        for mod in linear_modules:
            # FIX: the dummy input must match each layer's input width and
            # live on the same device as its weight; the original fixed
            # (1, 256) CPU tensor failed for every layer with
            # in_features != 256 and mismatched the CUDA weights.
            dummy_input = torch.randn(1, mod.in_features, device=mod.weight.device)
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
    except Exception as err:
        print("Warning: Quantization adjustments or VAE tiling failed:", err)

    try:
        # FIX: the VAE method is `enable_tiling()`; `enable_vae_tiling()` is
        # the pipeline-level wrapper name, not a VAE attribute. Kept in its
        # own try-block so a quantization failure no longer disables tiling.
        pipeline_instance.vae.enable_tiling()
    except Exception as err:
        print("Warning: Quantization adjustments or VAE tiling failed:", err)

    # Warm-up passes: trigger lazy compilation/autotuning so the first real
    # request does not pay the one-time cost.
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance
|
|
|
|
| |
| |
| |
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate a single image from a text-to-image request.

    Clears the CUDA cache, seeds a generator on the pipeline's device for
    reproducibility, and runs the 4-step schnell sampling schedule.

    Parameters:
        request (TextToImageRequest): Carries prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The pipeline returned by load_pipeline().

    Returns:
        Image: The first (and only) generated PIL image.
    """
    torch.cuda.empty_cache()
    generator = Generator(pipeline.device)
    generator.manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        generator=generator,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        output_type="pil",
    )
    return result.images[0]
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Build the pipeline once (includes warm-up), then run one demo request.
    diffusion_pipe = load_pipeline()

    sample_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        seed=1234,
        width=512,
        height=512,
    )

    result_image = inference(sample_request, diffusion_pipe)
|
|