import gc
import os
import time

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, FluxPipeline
from PIL.Image import Image
from torch import Generator
from torch.ao.quantization import QConfig, default_dynamic_qconfig, quantize_dynamic
from torch.ao.quantization.observer import MinMaxObserver
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest

# from torchao.quantization import quantize_, int8_weight_only

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
Pipeline = None

ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache():
    """Release cached CUDA memory and reset the peak-memory statistics."""
    start = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.time() - start:.2f}s")

def load_pipeline() -> Pipeline:
    empty_cache()
    dtype = torch.bfloat16

    # Load the T5-XXL text encoder from a standalone bf16 repack of the
    # weights rather than from the base checkpoint.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    )
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
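
    # cuDNN autotuning and TF32 matmuls trade exact fp32 reproducibility for
    # throughput; set_per_process_memory_fraction caps this process at 99% of
    # the GPU's memory.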
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.cuda.set_per_process_memory_fraction(0.99)
    # channels_last only affects 4D weights, so this is effectively a no-op
    # for the text encoder.
    pipeline.text_encoder.to(memory_format=torch.channels_last)

    # Dynamic int8 quantization of the transformer's projection layers.
    # float8_e5m2fnuz was the intended target, but eager-mode quantization has
    # no float8 support, so qint8 stands in. Dynamic quantization only
    # consults the weight observer; the activation entry is the standard
    # dynamic placeholder.
    custom_qconfig = QConfig(
        activation=default_dynamic_qconfig.activation,
        weight=MinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
        ),
    )

    # quantize_dynamic matches string keys against *fully qualified* module
    # names, so bare suffixes such as "to_q" would only ever match top-level
    # modules. Expand each suffix to the full name of every matching Linear.
    target_suffixes = {
        "linear", "linear_1", "linear_2",
        "to_q", "to_k", "to_v",
        "add_q_proj", "add_k_proj", "add_v_proj",
        "proj", "proj_mlp", "proj_out",
    }
    qconfig_spec = {
        name: custom_qconfig
        for name, module in pipeline.transformer.named_modules()
        if isinstance(module, torch.nn.Linear)
        and name.rsplit(".", 1)[-1] in target_suffixes
    }

    # Apply dynamic quantization (int8 weights, activations quantized on the
    # fly) to the selected projections; the dtype argument is ignored when
    # qconfig_spec is a dict, so it is omitted.
    pipeline.transformer = quantize_dynamic(
        pipeline.transformer,
        qconfig_spec=qconfig_spec,
        inplace=True,
    )
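
    # Caveat: the eager quantizer expects fp32 source weights, so the bf16
    # Linears may need an explicit .float() before conversion. A GPU-friendly
    # alternative would be torchao weight-only quantization (import commented
    # out at the top): quantize_(pipeline.transformer, int8_weight_only()).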

    # channels_last suits the convolutional VAE; torch.compile speeds up decode.
    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae)

    # Keep the VAE out of any CPU-offload bookkeeping.
    pipeline._exclude_from_cpu_offload = ["vae"]
    # pipeline.enable_sequential_cpu_offload() is the built-in alternative to
    # the manual variant below, which moves a component's parameters and
    # buffers in place. (Iterating state_dict() would not work here: it
    # returns detached views whose .data can be rebound without touching the
    # live module.)
    def custom_cpu_offload(model, device):
        for tensor in list(model.parameters()) + list(model.buffers()):
            tensor.data = tensor.data.to(device)

    # Keep the large components resident in system RAM; the dynamically
    # quantized transformer only has CPU kernels anyway.
    custom_cpu_offload(pipeline.text_encoder, "cpu")
    custom_cpu_offload(pipeline.text_encoder_2, "cpu")
    custom_cpu_offload(pipeline.transformer, "cpu")

    # Two warm-up passes with a nonsense prompt: the first pays the
    # torch.compile and cuDNN autotuning cost, the second verifies the
    # steady-state path.
    for _ in range(2):
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024, height=1024, guidance_scale=0.0,
            num_inference_steps=4, max_sequence_length=256,
        )

    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    torch.cuda.reset_peak_memory_stats()
    # Seed the generator on the pipeline's execution device so it matches
    # wherever the components actually live; a "cuda" generator fails against
    # CPU-resident modules.
    generator = Generator(pipeline.device).manual_seed(request.seed)
    image = pipeline(
        request.prompt, generator=generator, guidance_scale=0.0,
        num_inference_steps=4, max_sequence_length=256,
        height=request.height, width=request.width, output_type="pil",
    ).images[0]
    return image
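

# A minimal local smoke test, assuming TextToImageRequest exposes the prompt,
# seed, height, and width fields read by infer() above (field names inferred
# from their usage, not from the pipelines.models source).
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        seed=0,
        height=1024,
        width=1024,
    )
    infer(request, pipe).save("output.png")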