File size: 3,802 Bytes
01dbf80
3db4312
 
01dbf80
 
3db4312
01dbf80
3db4312
01dbf80
 
3db4312
01dbf80
3db4312
01dbf80
 
3db4312
01dbf80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3db4312
 
01dbf80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3db4312
01dbf80
3db4312
01dbf80
 
 
 
 
 
 
 
 
 
 
3db4312
 
01dbf80
 
 
 
 
 
3db4312
 
 
 
 
 
 
 
01dbf80
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
import gc
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from transformers import T5EncoderModel
from torch import Generator
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest

# Suppress errors and optimize CUDA memory allocation
# torch.compile/dynamo: fall back to eager instead of raising on graph breaks.
torch._dynamo.config.suppress_errors = True
# Allow the CUDA caching allocator to grow segments, reducing fragmentation OOMs.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Keep HF tokenizers' internal parallelism enabled (silences the fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Placeholder used purely as a type annotation below (evaluates to None at def time).
Pipeline = None
# Model Checkpoints
CKPT = "black-forest-labs/FLUX.1-schnell"
# Pinned revision so the base pipeline is reproducible across runs.
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"


def convoluted_quantization(c, w, ws, wz, is_, iz, os_, oz):
    """
    Simulate affine (asymmetric) quantization of a linear layer's output.

    The input ``c`` and weight ``w`` are shifted by their zero points
    (``iz``, ``wz``), passed through a linear transform, rescaled by the
    combined input/weight scales over the output scale
    (``is_ * ws / os_``), re-centered on the output zero point ``oz``,
    then rounded and clamped to the uint8 range [0, 255].
    """
    shifted_input = c.float() - iz
    shifted_weight = w.float() - wz
    accum = torch.nn.functional.linear(shifted_input, shifted_weight)
    rescaled = accum * (is_ * ws) / os_ + oz
    return torch.round(rescaled).clamp(min=0, max=255)


class ModelLoader:
    """Static helpers that fetch the pipeline's sub-models in bfloat16."""

    @staticmethod
    def initialize_text_encoder() -> T5EncoderModel:
        """Download the pinned secondary T5 text encoder and convert it
        to channels-last memory format."""
        print("Loading text encoder...")
        encoder = T5EncoderModel.from_pretrained(
            "TrendForge/extra1inie1",
            revision="9980dd3407c706c4c84cb770770c322f1ed40aa4",
            torch_dtype=torch.bfloat16,
        )
        return encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def initialize_transformer(transformer_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer from a local snapshot path (non-safetensors
        weights) and convert it to channels-last memory format."""
        print("Loading transformer model...")
        model = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        return model.to(memory_format=torch.channels_last)


def load_pipeline() -> Pipeline:
    """Build, (best-effort) quantize, and warm up the FLUX.1-schnell pipeline.

    Loads the custom text encoder and transformer, assembles the diffusers
    pipeline on CUDA, attempts an optional quantization/CUDA-graph step
    (silently falling back on failure), then runs three warm-up generations.
    """
    print("Initializing pipeline...")
    
    # Custom secondary text encoder (replaces the default T5 of the base ckpt).
    encoder_2 = ModelLoader.initialize_text_encoder()
    # NOTE(review): hard-coded local HF-cache snapshot path — assumes the
    # "TrendForge/extra0inie0" model was pre-downloaded; verify at deploy time.
    trans_path = os.path.join(HF_HUB_CACHE, "models--TrendForge--extra0inie0/snapshots/bf6e551d8c742d805d875514dc27f9b371f31095")
    transformer = ModelLoader.initialize_transformer(trans_path)

    # Base pipeline with the two sub-models swapped in, moved to GPU.
    flux_pipeline = DiffusionPipeline.from_pretrained(
        CKPT,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=encoder_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    try:
        # NOTE(review): `enable_quantization`/`enable_cuda_graph` are not
        # standard DiffusionPipeline methods — presumably provided by a
        # patched diffusers build. If absent, the AttributeError is caught
        # below and the unmodified pipeline is used.
        flux_pipeline.enable_quantization()
        linear_layers = [layer for layer in flux_pipeline.transformer.layers if "Convolution" in dir(layer)]
        for layer in linear_layers:
            # Result is discarded; presumably exercised only for its side
            # effects (e.g. calibration) — TODO confirm.
            convoluted_quantization(
                c=torch.randn(1, 256),
                w=layer.weight,
                ws=0.1,
                wz=0,
                is_=0.1,
                iz=0,
                os_=0.1,
                oz=0,
            )
        flux_pipeline.enable_cuda_graph()
    except Exception as e:
        # Deliberate best-effort: any failure above leaves the plain pipeline.
        print("Fallback to origin pipeline due to error:", e)

    # Warm-up inference
    # Three runs to trigger compilation/caching before serving real requests.
    for _ in range(3):
        flux_pipeline(
            prompt="fretful, becalmment, ventriduct, anthologion, tiptoppish, return, non-duplicate",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    
    # Release warm-up allocations back to the driver before serving.
    torch.cuda.empty_cache()
    return flux_pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """
    Run one text-to-image generation and return the resulting PIL image.

    The CUDA allocator cache is cleared before sampling to limit
    fragmentation between requests; sampling uses the 4-step schnell
    schedule with guidance disabled.
    """
    torch.cuda.empty_cache()
    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        output_type="pil",
    )
    return result.images[0]