File size: 3,482 Bytes
f65f896
 
 
 
 
 
 
7a37b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f65f896
7a37b0b
f65f896
 
7a37b0b
 
 
 
 
f65f896
7a37b0b
f65f896
 
 
7a37b0b
 
f65f896
7a37b0b
 
 
 
 
f65f896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a37b0b
f65f896
 
7a37b0b
f65f896
 
 
 
 
 
 
 
 
7a37b0b
 
f65f896
7a37b0b
f65f896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import gc
import torch
import numpy as np
from PIL import Image
from typing import Optional

from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
    FluxPipeline,
    FluxTransformer2DModel
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
    CLIPTokenizer,
    CLIPTextModel
)
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only

# Pre-configurations
# Let the CUDA caching allocator grow segments instead of fragmenting;
# must be set before the first CUDA allocation to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Opt in to HF tokenizer parallelism (also silences the fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Fall back to eager execution instead of crashing if dynamo/compile fails.
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls on Ampere+ for speed
torch.backends.cudnn.enabled = True

# Global variables
Pipeline = None  # NOTE(review): appears unused in this chunk — confirm callers before removing
CKPT_ID = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"


def empty_cache():
    """Collect garbage, release cached GPU memory, and reset CUDA peak stats.

    Safe on CPU-only hosts: the CUDA calls are skipped when no CUDA device
    is available (the original raised in that case).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() supersedes the deprecated
        # reset_max_memory_allocated() and resets the same counters,
        # so a single call is sufficient.
        torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> FluxPipeline:
    """Build and warm up the FLUX.1-schnell text-to-image pipeline.

    Loads a fine-tuned T5 text encoder (from the "Chrissy1/extra0manQ0"
    repo) and a transformer read directly from the local HF hub cache,
    an int8-weight-quantized VAE from the base FLUX.1-schnell checkpoint,
    assembles them into a FluxPipeline on CUDA, and runs one throwaway
    1024x1024 generation so later calls skip first-run initialization cost.

    Returns:
        FluxPipeline: the CUDA-resident, warmed-up pipeline.
    """
    # Load text encoder (bfloat16; channels_last memory format for faster kernels)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "Chrissy1/extra0manQ0",
        revision="c0db1e82d89825a4664ad873f20d261cbe46e737",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16
    ).to(memory_format=torch.channels_last)

    # Load transformer
    # NOTE(review): path is built by hand inside HF_HUB_CACHE, so this assumes
    # the snapshot was already downloaded by a prior step — confirm the
    # deployment pre-populates the cache.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--Chrissy1--extra0manQ0/snapshots/c0db1e82d89825a4664ad873f20d261cbe46e737/transformer"
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, 
        torch_dtype=torch.bfloat16, 
        use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # Load and quantize autoencoder
    # local_files_only=True: also relies on a pre-populated cache (no download).
    vae = AutoencoderKL.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16
    )
    # In-place int8 weight-only quantization to cut VAE memory footprint.
    quantize_(vae, int8_weight_only())

    # Load FluxPipeline with the custom transformer/text encoder swapped in;
    # remaining components (CLIP encoder, tokenizers, scheduler) come from the
    # base checkpoint.
    pipeline = FluxPipeline.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16
    )
    pipeline.to("cuda")

    # Warm-up run to ensure the pipeline is ready
    # (prompt content is irrelevant; matches infer()'s settings: 4 steps,
    # guidance 0.0, max_sequence_length 256).
    with torch.inference_mode():
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256
        )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: FluxPipeline, generator: Generator) -> Image:
    """Generate one PIL image for *request* with a 4-step, guidance-free run."""
    # Drop any cached GPU memory left over from the previous request.
    empty_cache()

    generation_kwargs = dict(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    output = pipeline(**generation_kwargs)
    return output.images[0]