File size: 4,776 Bytes
8e3f1b7
3e39c60
 
 
 
8e3f1b7
 
524d6b8
 
 
 
 
 
3e39c60
 
8e3f1b7
524d6b8
3e39c60
524d6b8
3e39c60
 
524d6b8
 
8e3f1b7
 
 
 
524d6b8
8e3f1b7
 
 
 
 
 
 
 
 
 
 
524d6b8
8e3f1b7
524d6b8
8e3f1b7
 
524d6b8
8e3f1b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524d6b8
8e3f1b7
 
 
524d6b8
 
8e3f1b7
524d6b8
8e3f1b7
3e39c60
8e3f1b7
524d6b8
 
 
 
 
 
 
8e3f1b7
 
524d6b8
 
 
 
 
 
8e3f1b7
524d6b8
8e3f1b7
524d6b8
 
 
 
8e3f1b7
 
 
 
 
 
524d6b8
 
8e3f1b7
 
524d6b8
 
 
 
 
 
 
 
 
3e39c60
 
524d6b8
 
 
 
 
 
 
 
 
 
 
3e39c60
 
524d6b8
3e39c60
 
 
 
 
 
 
 
 
 
524d6b8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import torch
import torch._dynamo
import gc
from PIL.Image import Image
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
    T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
)
from diffusers import (
    FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
)
from pipelines.models import TextToImageRequest
from torch import Generator

# Process-wide runtime settings, applied once when this module is imported.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "True",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
# Let TorchDynamo suppress its own errors instead of propagating them.
torch._dynamo.config.suppress_errors = True

# Hub repository and pinned commit of the base pipeline.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder name referenced by the annotations of load_pipeline and infer.
Pipeline = None


class QuantativeAnalysis:
    """Pass-through wrapper that holds a model and two quantisation settings.

    ``apply`` does not modify the wrapped model: no parameter is ever
    written. The per-parameter min/max statistics that used to be computed
    here were assigned to an unused local and discarded, so that dead work
    (one full reduction per parameter, plus a crash on zero-element
    parameters) has been removed.

    Attributes:
        model: The wrapped model, returned as-is by ``apply``.
        num_bins: Stored for callers; not used by ``apply``.
        scale_ratio: Stored for callers; not used by ``apply``.
    """

    def __init__(self, model, num_bins=256, scale_ratio=1.0):
        self.model = model
        self.num_bins = num_bins
        self.scale_ratio = scale_ratio

    def apply(self):
        """Return the wrapped model unchanged."""
        return self.model


class AttentionQuant:
    """Snap selected parameters of a model onto a uniform grid, in place.

    ``att_config`` maps a name to a ``(num_bins, scale_factor)`` pair. A
    parameter is selected when it has ``requires_grad`` set and the FIRST
    dotted component of its name (the part before the first ``"."`` as
    reported by ``named_parameters``) is a key of ``att_config``. Keys that
    contain dots can therefore never match.
    """

    def __init__(self, model, att_config):
        self.model = model
        self.att_config = att_config

    def apply(self):
        """Quantise every selected parameter and return the model.

        Each selected tensor is normalised to [0, 1] over its own min/max,
        rounded to ``num_bins`` evenly spaced levels, mapped back to the
        original range and multiplied by ``scale_factor``. A tensor whose
        range is not positive (for example, all values equal) is set to
        zero. Zero-element tensors are left alone.

        Returns:
            The same model object that was passed to the constructor.

        Raises:
            ValueError: If a matching config entry has ``num_bins < 2``.
                Raised before any parameter is modified.
        """
        # First pass: resolve and validate the work list without mutating
        # anything, so a bad config entry cannot leave the model half-updated.
        planned = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            top_level = name.split(".", 1)[0]
            if top_level not in self.att_config:
                continue
            num_bins, scale_factor = self.att_config[top_level]
            if num_bins < 2:
                # A single level would divide by (num_bins - 1) == 0 and
                # fill the weights with NaN.
                raise ValueError(
                    f"num_bins for {top_level!r} must be at least 2, got {num_bins}"
                )
            if param.numel() == 0:
                # min()/max() are undefined on zero-element tensors.
                continue
            planned.append((param, num_bins, scale_factor))

        # Second pass: overwrite the selected parameters in place.
        with torch.no_grad():
            for param, num_bins, scale_factor in planned:
                low = param.min()
                span = param.max() - low
                if span > 0:
                    levels = num_bins - 1
                    snapped = torch.round((param - low) / span * levels) / levels
                    param.copy_((snapped * span + low) * scale_factor)
                else:
                    param.zero_()
        return self.model


def load_pipeline() -> Pipeline:
    """Build the text-to-image pipeline, move it to CUDA and warm it up.

    The T5 text encoder, the tiny VAE and the Flux transformer are loaded
    separately in bfloat16 and injected into the pipeline created from
    ``CHECKPOINT`` at ``REVISION``. Three 1024x1024, 4-step generations are
    run before the pipeline is returned.

    Returns:
        The pipeline object, resident on the ``"cuda"`` device.
    """
    half = torch.bfloat16

    # Text encoder (passed to the pipeline as ``text_encoder_2``).
    t5_encoder = T5EncoderModel.from_pretrained(
        "TrendForge/extra1manQ1",
        revision="d302b6e39214ed4532be34ec337f93c7eef3eaa6",
        torch_dtype=half,
    )
    t5_encoder = t5_encoder.to(memory_format=torch.channels_last)

    # Tiny autoencoder used as the pipeline's VAE.
    tiny_vae = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2manQ2",
        revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
        torch_dtype=half,
    )

    # The transformer is read from a snapshot directory inside the local
    # Hugging Face hub cache rather than by repository id.
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4",
    )
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=half,
        use_safetensors=False,
    )
    flux_transformer = flux_transformer.to(memory_format=torch.channels_last)

    # NOTE(review): AttentionQuant compares these keys with the first dotted
    # component of each parameter name, so fully qualified names like the
    # ones below look like they can never match - confirm whether this step
    # is meant to be a no-op.
    attention_settings = {
        "transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
        "transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
        "transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1),
    }
    try:
        transformer = AttentionQuant(flux_transformer, attention_settings).apply()
    except Exception:
        # Best effort: carry on with the transformer object as loaded.
        transformer = flux_transformer

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=tiny_vae,
        transformer=transformer,
        text_encoder_2=t5_encoder,
        torch_dtype=half,
    )
    pipeline.to("cuda")

    # Warm-up runs use the same step count, guidance scale and sequence
    # length that infer() passes.
    warmup_args = dict(
        prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    for _ in range(3):
        pipeline(**warmup_args)

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    A generator created on the pipeline's device is seeded with
    ``request.seed``. The pipeline is then called with the request's prompt,
    height and width, using 4 inference steps, a guidance scale of 0.0 and a
    maximum sequence length of 256.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    seeded_rng = Generator(pipeline.device)
    seeded_rng.manual_seed(request.seed)

    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=seeded_rng,
    )
    return result.images[0]