import os
import gc
import json
import math
from typing import Any, Dict
import torch
from torch import Generator
import torch._dynamo
import transformers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from huggingface_hub.constants import HF_HUB_CACHE
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from PIL.Image import Image
# -----------------------------------------------------------------------------
# Environment Configuration & Global Constants
# -----------------------------------------------------------------------------
# Fall back to eager execution instead of raising when torch.compile/dynamo
# hits an op it cannot trace (graph breaks are tolerated).
torch._dynamo.config.suppress_errors = True
# Let the CUDA caching allocator grow segments to reduce fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Keep HF fast-tokenizer parallelism enabled.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Identifiers for the diffusion model checkpoint (repo id + pinned revision).
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# -----------------------------------------------------------------------------
# Quantization and Linear Transformation Utilities
# -----------------------------------------------------------------------------
def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
) -> torch.Tensor:
    """
    Apply a quantization-aware linear transform: dequantize both operands,
    run the linear layer, then requantize the result.

    Parameters:
        input_tensor (torch.Tensor): Quantized input activations.
        weight_tensor (torch.Tensor): Quantized weight matrix.
        w_scale (float): Scale factor for the weights.
        w_zero (int): Zero-point for the weights.
        in_scale (float): Scale factor for the input.
        in_zero (int): Zero-point for the input.
        out_scale (float): Scale factor for the output.
        out_zero (int): Zero-point for the output.

    Returns:
        torch.Tensor: Requantized output, rounded and clamped to the
        uint8 value range [0, 255] (returned as a float tensor).
    """
    # Shift each operand by its zero-point; the scale factors are folded
    # into the single requantization multiply below.
    centered_input = input_tensor.float() - in_zero
    centered_weight = weight_tensor.float() - w_zero
    # y = x @ W^T via the standard linear op (no bias).
    accumulated = torch.nn.functional.linear(centered_input, centered_weight)
    # Fold (in_scale * w_scale) / out_scale into one rescale, add the
    # output zero-point, then round and saturate into [0, 255].
    rescaled = accumulated * ((in_scale * w_scale) / out_scale) + out_zero
    return torch.round(rescaled).clamp(0, 255)
# -----------------------------------------------------------------------------
# Model Initialization Functions
# -----------------------------------------------------------------------------
def initialize_text_encoder() -> T5EncoderModel:
    """
    Load the bf16 T5-XXL text encoder (pinned revision) and return it
    converted to channels-last memory format.
    """
    print("Initializing T5 text encoder...")
    t5_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    t5_encoder = t5_encoder.to(memory_format=torch.channels_last)
    return t5_encoder
def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Load the Flux transformer from a local snapshot directory and return it
    converted to channels-last memory format.

    Parameters:
        transformer_dir (str): Path to the transformer snapshot directory.
    """
    print("Initializing Flux transformer...")
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    flux_transformer = flux_transformer.to(memory_format=torch.channels_last)
    return flux_transformer
# -----------------------------------------------------------------------------
# Pipeline Construction
# -----------------------------------------------------------------------------
def load_pipeline() -> DiffusionPipeline:
    """
    Construct the diffusion pipeline from the cached INT8 transformer
    snapshot and the bf16 T5 encoder, enable VAE tiling, run a dummy
    quantization pass over the transformer's Linear layers, and perform
    warm-up inferences to stabilize performance.

    Returns:
        DiffusionPipeline: The configured, CUDA-resident pipeline.
    """
    # Build the path to the transformer snapshot inside the HF hub cache.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)
    encoder = initialize_text_encoder()
    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # Enable VAE tiling separately from the quantization pass so a failure in
    # one cannot silently skip the other. In diffusers, `enable_vae_tiling`
    # is a pipeline-level method (the VAE itself exposes `enable_tiling`);
    # the original called a non-existent `vae.enable_vae_tiling`, which
    # always raised and was swallowed by the broad except below.
    try:
        pipeline_instance.enable_vae_tiling()
    except Exception as err:
        print("Warning: enabling VAE tiling failed:", err)
    try:
        # Walk all submodules — FluxTransformer2DModel has no `.layers`
        # attribute — and select the Linear layers by type. The original
        # tested `mod.__classname__`, which is not a real attribute and made
        # this whole pass fail with AttributeError on the first module.
        linear_modules = [
            mod
            for mod in pipeline_instance.transformer.modules()
            if isinstance(mod, torch.nn.Linear)
        ]
        for mod in linear_modules:
            # Size the dummy input to the layer's fan-in: a fixed (1, 256)
            # shape fails F.linear for every layer whose in_features != 256.
            dummy_input = torch.randn(1, mod.in_features, device=mod.weight.device)
            # Dummy quantization adjustment using placeholder scale factors.
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
    except Exception as err:
        print("Warning: Quantization adjustments failed:", err)
    # Run several warm-up inferences so CUDA kernel selection/compilation
    # happens before the first real request.
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance
# -----------------------------------------------------------------------------
# Inference Function
# -----------------------------------------------------------------------------
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate one image for the given text-to-image request.

    Clears the CUDA allocator cache, seeds a device-local RNG from the
    request, and runs the diffusion pipeline.

    Parameters:
        request (TextToImageRequest): Carries prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The pipeline to run inference with.

    Returns:
        Image: The generated PIL image.
    """
    # Release cached allocations from any previous request first.
    torch.cuda.empty_cache()
    # Deterministic generation: RNG lives on the pipeline's device and is
    # seeded per request.
    seeded_generator = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]
# -----------------------------------------------------------------------------
# Example Main Flow (Optional)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the pipeline once, then run a single sample request through it.
    diffusion_pipe = load_pipeline()
    # Sample request (assumes TextToImageRequest is appropriately defined).
    sample_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        height=512,
        width=512,
        seed=1234,
    )
    result_image = inference(sample_request, diffusion_pipe)
    # 'result_image' may now be saved or displayed as desired.