# Pulse-AI-651 / src/pipeline.py
# Uploaded via huggingface_hub (commit 8f9a9c3).
import os
import gc
import json
import math
from typing import Any, Dict
import torch
from torch import Generator
import torch._dynamo
import transformers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from huggingface_hub.constants import HF_HUB_CACHE
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from PIL.Image import Image
# -----------------------------------------------------------------------------
# Environment Configuration & Global Constants
# -----------------------------------------------------------------------------
# Fall back to eager execution instead of crashing when torch.compile/dynamo
# hits an unsupported construct.
torch._dynamo.config.suppress_errors = True
# Let the CUDA caching allocator grow existing segments, reducing
# out-of-memory errors caused by fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Enable parallelism in the HF fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Identifiers for the diffusion model checkpoint (FLUX.1-schnell, pinned to a
# specific revision for reproducibility).
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# -----------------------------------------------------------------------------
# Quantization and Linear Transformation Utilities
# -----------------------------------------------------------------------------
def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
) -> torch.Tensor:
    """
    Apply a quantization-aware linear transformation.

    Both operands are shifted by their zero-points, passed through a
    standard linear layer, and the product is rescaled into the output
    quantization grid, rounded, and clamped to the uint8 range [0, 255].

    Parameters:
        input_tensor (torch.Tensor): Quantized input (activation) tensor.
        weight_tensor (torch.Tensor): Quantized weight tensor.
        w_scale (float): Scale factor for the weights.
        w_zero (int): Zero-point for the weights.
        in_scale (float): Scale factor for the input.
        in_zero (int): Zero-point for the input.
        out_scale (float): Scale factor for the output.
        out_zero (int): Zero-point for the output.

    Returns:
        torch.Tensor: The requantized output tensor (float values on the
        integer grid, clamped to [0, 255]).
    """
    # Zero-point shift only; the scale factors are folded into a single
    # multiplier after the matmul, which is mathematically equivalent to
    # dequantizing each operand individually.
    shifted_in = input_tensor.float() - in_zero
    shifted_wt = weight_tensor.float() - w_zero
    # y = x @ W^T
    product = torch.nn.functional.linear(shifted_in, shifted_wt)
    # Fold all three scales into one multiplier and re-center on the
    # output zero-point.
    combined_scale = (in_scale * w_scale) / out_scale
    rescaled = product * combined_scale + out_zero
    # Snap to the nearest representable level and clamp to uint8 range.
    return torch.clamp(torch.round(rescaled), 0, 255)
# -----------------------------------------------------------------------------
# Model Initialization Functions
# -----------------------------------------------------------------------------
def initialize_text_encoder() -> T5EncoderModel:
    """
    Load the bf16 T5-XXL text encoder (pinned revision) and return it in
    channels-last memory format.
    """
    print("Initializing T5 text encoder...")
    model = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    # Match the channels-last layout used by the other pipeline components.
    return model.to(memory_format=torch.channels_last)
def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Load the Flux transformer from a local snapshot directory and return
    it in channels-last memory format.

    Parameters:
        transformer_dir (str): Path to the transformer snapshot.
    """
    print("Initializing Flux transformer...")
    model = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        # The snapshot ships non-safetensors weight files.
        use_safetensors=False,
    )
    return model.to(memory_format=torch.channels_last)
# -----------------------------------------------------------------------------
# Pipeline Construction
# -----------------------------------------------------------------------------
def load_pipeline() -> DiffusionPipeline:
    """
    Construct the diffusion pipeline from the text encoder and transformer.

    Also runs a dummy quantization pass over the transformer's linear
    submodules, enables VAE tiling, and performs several warm-up calls so
    subsequent inferences run at a stable speed.

    Returns:
        DiffusionPipeline: The configured, CUDA-resident pipeline.
    """
    # Path to the locally cached INT8 transformer snapshot.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)
    encoder = initialize_text_encoder()
    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    try:
        # Collect the transformer's Linear-style submodules.
        # Fix: the previous code read a nonexistent ``__classname__``
        # attribute over a nonexistent ``.layers`` collection, so this
        # step always raised and was silently skipped.
        linear_modules = [
            mod for mod in pipeline_instance.transformer.modules()
            if "Linear" in mod.__class__.__name__ and hasattr(mod, "weight")
        ]
        for mod in linear_modules:
            # Dummy activation matching the layer's input width, device,
            # and dtype so the linear call is actually executable
            # (a fixed CPU-resident (1, 256) tensor cannot feed CUDA
            # weights of arbitrary in_features).
            dummy_input = torch.randn(
                1,
                mod.weight.shape[1],
                device=mod.weight.device,
                dtype=mod.weight.dtype,
            )
            # Dummy quantization adjustment using neutral scale/zero values.
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
        # Fix: the VAE method is ``enable_tiling()``; the previous
        # ``vae.enable_vae_tiling()`` call named the pipeline-level method
        # and raised AttributeError.
        pipeline_instance.vae.enable_tiling()
    except Exception as err:
        # Best-effort: the pipeline is still usable without these tweaks.
        print("Warning: Quantization adjustments or VAE tiling failed:", err)
    # Warm-up inferences to populate kernels/caches before real requests.
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance
# -----------------------------------------------------------------------------
# Inference Function
# -----------------------------------------------------------------------------
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate one image for a text-to-image request.

    Frees cached GPU memory, seeds a generator from the request, and runs
    the diffusion pipeline to produce a single PIL image.

    Parameters:
        request (TextToImageRequest): Contains prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The diffusion pipeline to run inference.

    Returns:
        Image: The generated image.
    """
    # Release cached allocations left over from previous requests.
    torch.cuda.empty_cache()
    # Seed a generator on the pipeline's device for deterministic sampling.
    generator = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        output_type="pil",
    )
    return result.images[0]
# -----------------------------------------------------------------------------
# Example Main Flow (Optional)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the pipeline once, then serve a single example request.
    pipe = load_pipeline()
    demo_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        height=512,
        width=512,
        seed=1234,
    )
    image = inference(demo_request, pipe)
    # 'image' can now be saved or displayed as desired.