File size: 3,794 Bytes
b8b4dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel

import torch
import torch._dynamo
import os

# Environment optimizations
# Let the CUDA caching allocator grow existing segments instead of failing on
# fragmentation-induced OOM.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Enable parallel tokenization in the HF tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Fall back to eager execution rather than raising when dynamo/torch.compile
# hits an unsupported construct.
torch._dynamo.config.suppress_errors = True

# NOTE(review): used only as a return annotation below; evaluates to None at
# import time — presumably a placeholder for DiffusionPipeline. Confirm intent.
pipeline_class = None
model_checkpoint = "black-forest-labs/FLUX.1-schnell"  # base pipeline repo
model_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"  # pinned commit

class NormalizationQuantization:
    """Stochastic floor-quantization pass applied in place to a model.

    Every trainable parameter is perturbed with Gaussian noise scaled by
    ``noise_level`` and then floored; every registered buffer is shifted by a
    constant 0.01 offset. The model is modified in place and also returned.
    """

    def __init__(self, model, noise_level=0.05):
        # Module whose tensors will be rewritten in place.
        self.model = model
        # Std-dev multiplier for the pre-floor Gaussian perturbation.
        self.noise_level = noise_level

    def apply(self):
        """Quantize parameters and shift buffers in place; return the model."""
        with torch.no_grad():
            for _name, weight in self.model.named_parameters():
                # Frozen parameters are left untouched.
                if not weight.requires_grad:
                    continue
                perturbation = torch.randn_like(weight.data) * self.noise_level
                weight.data = torch.floor(weight.data + perturbation)

            for _name, buf in self.model.named_buffers():
                buf.add_(torch.full_like(buf, 0.01))

        return self.model

def load_diffusion_pipeline() -> pipeline_class:
    """Assemble the FLUX.1-schnell pipeline from custom components.

    Loads a tiny VAE, a T5 text encoder, and a FLUX transformer (all in
    bfloat16), runs the quantization pass on the encoder and transformer
    (falling back to the unmodified module on failure), moves the pipeline
    to CUDA, and performs three warm-up generations before returning it.
    """

    def _maybe_quantize(module, label):
        # Best-effort quantization: keep the original module if the pass fails.
        try:
            return NormalizationQuantization(module, noise_level=0.03).apply()
        except Exception as e:
            print(f"Failed to apply normalization quantization on {label}: {e}")
            return module

    tiny_vae = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2Jan12",
        revision="da7c5cf904a9dbba65a7282396befa49623cd9cd",
        torch_dtype=torch.bfloat16,
    )

    encoder = T5EncoderModel.from_pretrained(
        "TrendForge/extra1Jan11",
        revision="c76831ddf0852be22835f79dc5c1fbacb1ccda9e",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    encoder = _maybe_quantize(encoder, "text encoder")

    # The transformer snapshot is resolved directly inside the local HF cache.
    snapshot_dir = os.path.join(
        HF_HUB_CACHE, 
        "models--TrendForge--extra0Jan10/snapshots/d3ded25a77fdef06de4059d94b080a34da6e7a82"
    )
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)
    flux_transformer = _maybe_quantize(flux_transformer, "transformer model")

    pipe = DiffusionPipeline.from_pretrained(
        model_checkpoint,
        revision=model_revision,
        vae=tiny_vae,
        transformer=flux_transformer,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # Warm up CUDA kernels / compilation caches with three throwaway runs.
    for _ in range(3):
        pipe(
            prompt="freezable, catacorolla, gaiassa, unenkindled, grubs, solidiform",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256
        )

    return pipe

@torch.no_grad()
def perform_inference(request: TextToImageRequest, pipeline: pipeline_class) -> Image:
    """Generate one image for *request* using the prepared pipeline.

    Seeds a generator on the pipeline's device from the request so results
    are reproducible, then returns the first generated PIL image.
    """
    rng = Generator(pipeline.device).manual_seed(request.seed)

    output = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return output.images[0]