|
|
import os |
|
|
import gc |
|
|
import json |
|
|
import math |
|
|
from typing import Any, Dict |
|
|
|
|
|
import torch |
|
|
from torch import Generator |
|
|
import torch._dynamo |
|
|
|
|
|
import transformers |
|
|
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel |
|
|
from huggingface_hub.constants import HF_HUB_CACHE |
|
|
|
|
|
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny |
|
|
from pipelines.models import TextToImageRequest |
|
|
|
|
|
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only |
|
|
|
|
|
from PIL.Image import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Let torch.compile/dynamo fall back to eager execution instead of raising
# when a graph fails to compile (keeps warm-up/compilation best-effort).
torch._dynamo.config.suppress_errors = True

# Reduce CUDA memory fragmentation by allowing the caching allocator to
# grow segments on demand. Must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Enable parallelism in the HuggingFace `tokenizers` Rust backend.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
|
|
|
|
|
|
|
# HuggingFace repo id of the base pipeline (scheduler, VAE, CLIP text encoder, ...).
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Pinned repo revision so deployments are reproducible across hub updates.
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
) -> torch.Tensor:
    """
    Apply a simulated quantized linear layer to ``input_tensor``.

    Both operands are shifted by their zero-points, combined through a
    bias-free linear transform, and the product is rescaled into the
    output's quantized domain: rounded and clamped to the uint8 value
    range [0, 255].

    Parameters:
        input_tensor (torch.Tensor): Quantized input activations.
        weight_tensor (torch.Tensor): Quantized weight matrix.
        w_scale (float): Scale factor for the weights.
        w_zero (int): Zero-point for the weights.
        in_scale (float): Scale factor for the input.
        in_zero (int): Zero-point for the input.
        out_scale (float): Scale factor for the output.
        out_zero (int): Zero-point for the output.

    Returns:
        torch.Tensor: Requantized output values in [0, 255].
    """
    # Remove the zero-point offsets so both operands are zero-centred.
    shifted_input = input_tensor.float() - in_zero
    shifted_weight = weight_tensor.float() - w_zero

    # Bias-free matrix product of the zero-centred operands.
    product = torch.nn.functional.linear(shifted_input, shifted_weight)

    # Fold the input/weight scales into the output scale, re-apply the
    # output zero-point, then snap to the uint8 value range.
    rescale = (in_scale * w_scale) / out_scale
    requantized_out = product * rescale + out_zero
    return torch.clamp(torch.round(requantized_out), 0, 255)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_text_encoder() -> T5EncoderModel:
    """
    Builds the bf16 T5-XXL text encoder from a pinned hub revision and
    returns it converted to channels-last memory format.
    """
    print("Initializing T5 text encoder...")
    model = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
|
|
|
|
|
|
|
|
def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Loads the Flux transformer weights from ``transformer_dir`` (bf16,
    non-safetensors checkpoint) and returns the model in channels-last
    memory format.
    """
    print("Initializing Flux transformer...")
    model = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_pipeline() -> DiffusionPipeline:
    """
    Constructs the diffusion pipeline by combining the text encoder and transformer.

    This function also applies a dummy quantization pass over the linear
    submodules of the transformer and enables VAE tiling (both best-effort).
    Finally, it performs several warm-up calls to stabilize performance.

    Returns:
        DiffusionPipeline: The configured diffusion pipeline, on CUDA.
    """
    # Pre-quantized INT8 transformer snapshot already present in the local
    # huggingface_hub cache.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)

    encoder = initialize_text_encoder()

    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Best-effort section: failures here only cost performance, so they are
    # logged rather than propagated.
    try:
        # BUG FIX: the original filtered `transformer.layers` on a
        # non-existent `__classname__` attribute, which raised
        # AttributeError immediately and silently skipped this whole
        # section. Walk all submodules and match real nn.Linear layers.
        linear_modules = [
            mod for mod in pipeline_instance.transformer.modules()
            if isinstance(mod, torch.nn.Linear)
        ]
        for mod in linear_modules:
            # BUG FIX: the dummy input must match the layer's in_features
            # and live on the same device as the weights; the original
            # hard-coded CPU tensor of width 256 failed for every layer
            # with a different width.
            dummy_input = torch.randn(1, mod.in_features, device=mod.weight.device)
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
        # BUG FIX: the VAE module exposes `enable_tiling()`;
        # `enable_vae_tiling` is a pipeline-level method and does not exist
        # on the VAE itself, so the original call always raised.
        pipeline_instance.vae.enable_tiling()
    except Exception as err:
        print("Warning: Quantization adjustments or VAE tiling failed:", err)

    # Warm-up passes so kernels/compilation are ready before the first
    # real request. Resolution and step count mirror production calls.
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate a single image for the given text-to-image request.

    Clears the CUDA allocator cache, seeds a per-request random generator
    on the pipeline's device, then runs the diffusion pipeline.

    Parameters:
        request (TextToImageRequest): Contains prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The diffusion pipeline to run inference.

    Returns:
        Image: The generated PIL image.
    """
    # Drop cached allocator blocks before the run to ease memory pressure.
    torch.cuda.empty_cache()

    # Deterministic sampling: one seeded generator per request.
    seeded_generator = Generator(pipeline.device).manual_seed(request.seed)

    generation = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return generation.images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the full pipeline: loads models, applies the dummy quantization
    # pass, and runs warm-up generations (slow; requires CUDA).
    diffusion_pipe = load_pipeline()

    # Smoke-test request with a fixed seed for reproducible output.
    sample_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        height=512,
        width=512,
        seed=1234
    )

    # Result is a PIL image; intentionally not saved/displayed here —
    # this entry point only exercises the load + inference path.
    result_image = inference(sample_request, diffusion_pipe)
|
|
|
|
|
|