File size: 7,961 Bytes
8f9a9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import os
import gc
import json
import math
from typing import Any, Dict

import torch
from torch import Generator
import torch._dynamo

import transformers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from huggingface_hub.constants import HF_HUB_CACHE

from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny
from pipelines.models import TextToImageRequest

from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only

from PIL.Image import Image

# -----------------------------------------------------------------------------
# Environment Configuration & Global Constants
# -----------------------------------------------------------------------------
# Fall back to eager execution instead of raising when dynamo/compile hits an
# unsupported construct.
torch._dynamo.config.suppress_errors = True
# Allow the CUDA caching allocator to grow segments, reducing fragmentation
# when large models are loaded and released.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Enable parallelism in the HuggingFace fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# Identifiers for the diffusion model checkpoint.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"


# -----------------------------------------------------------------------------
# Quantization and Linear Transformation Utilities
# -----------------------------------------------------------------------------
def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
) -> torch.Tensor:
    """
    Apply a quantization-aware linear transform to ``input_tensor``.

    Both operands are shifted by their zero-points (the scale factors are
    constant per-tensor, so they can be factored out of the matmul), run
    through a plain linear layer, then rescaled into the output
    quantization grid and saturated to the unsigned 8-bit range.

    Parameters:
        input_tensor (torch.Tensor): The input tensor.
        weight_tensor (torch.Tensor): The weight tensor.
        w_scale (float): Scale factor for the weights.
        w_zero (int): Zero-point for the weights.
        in_scale (float): Scale factor for the input.
        in_zero (int): Zero-point for the input.
        out_scale (float): Scale factor for the output.
        out_zero (int): Zero-point for the output.

    Returns:
        torch.Tensor: The requantized result, rounded and clamped to [0, 255].
    """
    # Zero-point shift on both operands (dequantization with scales deferred).
    shifted_input = input_tensor.float() - in_zero
    shifted_weight = weight_tensor.float() - w_zero

    # y = x @ W^T on the shifted values.
    product = torch.nn.functional.linear(shifted_input, shifted_weight)

    # Fold the deferred scales into a single factor, map into the output
    # grid, then round and saturate to the uint8 value range.
    combined_scale = (in_scale * w_scale) / out_scale
    requantized = product * combined_scale + out_zero
    return torch.round(requantized).clamp(0, 255)


# -----------------------------------------------------------------------------
# Model Initialization Functions
# -----------------------------------------------------------------------------
def initialize_text_encoder() -> T5EncoderModel:
    """
    Load the bf16 T5-XXL text encoder from a pinned HuggingFace revision.

    Returns:
        T5EncoderModel: The encoder converted to channels-last memory format.
    """
    print("Initializing T5 text encoder...")
    model = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(memory_format=torch.channels_last)
    return model


def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Load the Flux transformer from a local snapshot directory.

    Parameters:
        transformer_dir (str): Path to the on-disk model snapshot.

    Returns:
        FluxTransformer2DModel: The model converted to channels-last memory format.
    """
    print("Initializing Flux transformer...")
    model = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    model = model.to(memory_format=torch.channels_last)
    return model


# -----------------------------------------------------------------------------
# Pipeline Construction
# -----------------------------------------------------------------------------
def load_pipeline() -> DiffusionPipeline:
    """
    Construct the diffusion pipeline from the cached transformer snapshot
    and the T5 text encoder.

    A dummy quantization pass is run over each Linear layer of the
    transformer, VAE tiling is enabled to bound decode-time memory, and a
    few warm-up inferences are executed to stabilize performance.

    Returns:
        DiffusionPipeline: The configured diffusion pipeline on CUDA.
    """

    # Build the path to the transformer snapshot.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)

    encoder = initialize_text_encoder()

    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    try:
        # BUG FIX: the original filtered ``transformer.layers`` on a
        # non-existent ``__classname__`` attribute, which raised immediately
        # and silently skipped the rest of this try-block (including VAE
        # tiling). Walk all submodules and select real Linear layers instead.
        linear_modules = [
            mod for mod in pipeline_instance.transformer.modules()
            if isinstance(mod, torch.nn.Linear)
        ]
        for mod in linear_modules:
            # BUG FIX: size and place the dummy activation to match this
            # layer's weight; the original fixed (1, 256) CPU tensor could
            # not be multiplied against CUDA weights of other widths.
            dummy_input = torch.randn(
                1, mod.in_features, device=mod.weight.device
            )
            # Perform a dummy quantization adjustment using exponential notation.
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
        # BUG FIX: the VAE exposes ``enable_tiling()``; ``enable_vae_tiling``
        # is the *pipeline*-level method name in diffusers.
        pipeline_instance.vae.enable_tiling()
    except Exception as err:
        # Best-effort: these tweaks are optional, so log and continue.
        print("Warning: Quantization adjustments or VAE tiling failed:", err)

    # Run several warm-up inferences.
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance


# -----------------------------------------------------------------------------
# Inference Function
# -----------------------------------------------------------------------------
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate one image for a text-to-image request.

    Frees cached GPU memory, seeds a device-local random generator from the
    request, and runs the diffusion pipeline with the fixed schnell
    sampling configuration (4 steps, no guidance).

    Parameters:
        request (TextToImageRequest): Contains prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The diffusion pipeline to run inference.

    Returns:
        Image: The first generated PIL image.
    """
    torch.cuda.empty_cache()
    seeded_generator = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]


# -----------------------------------------------------------------------------
# Example Main Flow (Optional)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the pipeline once, then serve a single example request.
    diffusion_pipe = load_pipeline()

    # Sample request (TextToImageRequest is provided by pipelines.models).
    sample_request = TextToImageRequest(
        prompt="a scenic view of mountains at sunrise",
        height=512,
        width=512,
        seed=1234,
    )

    # Run inference; 'result_image' can then be saved or displayed.
    result_image = inference(sample_request, diffusion_pipe)