File size: 6,581 Bytes
14af353
5ed78ae
14af353
e20ee1a
5ed78ae
 
14af353
5ed78ae
14af353
5ed78ae
 
 
14af353
 
ffb19e3
14af353
3536ef2
aba977b
e20ee1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48d9237
5ed78ae
 
cc5f991
 
 
e360995
 
cc5f991
e360995
 
 
cc5f991
e360995
 
 
cc5f991
 
14af353
5ed78ae
14af353
5ed78ae
 
 
 
 
14af353
 
 
e20ee1a
 
 
 
99da2cc
542791d
e20ee1a
 
 
 
 
 
 
e360995
 
 
542791d
e360995
 
542791d
e360995
 
 
542791d
e360995
 
 
 
 
 
 
 
 
 
 
542791d
c7874a4
 
 
 
 
4571d44
165f952
e20ee1a
4843bb2
165f952
 
 
 
 
 
 
 
 
 
4ed9fcf
e20ee1a
181fdc1
e20ee1a
cd02ff1
e20ee1a
 
 
 
 
 
14af353
 
 
 
 
 
 
 
 
 
 
5ed78ae
14af353
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
import diffusers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
import torch.nn.utils.prune as prune
import numpy as np
from tqdm import tqdm
from optimum.quanto import requantize
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

def load_quanto_transformer(repo_path):
    """Build a FluxTransformer2DModel from a pre-quantized optimum-quanto checkpoint.

    The module skeleton is instantiated on the ``meta`` device from the repo's
    ``transformer/config.json`` and cast to bfloat16. ``requantize`` then fills
    it with the weights from ``transformer/diffusion_pytorch_model.safetensors``,
    following ``transformer/quantization_map.json``, and places them on CUDA.

    Args:
        repo_path: Hugging Face Hub repo id holding the quantized checkpoint
            (e.g. ``"Disty0/FLUX.1-dev-qint8"``).

    Returns:
        The requantized transformer, placed on ``cuda`` by ``requantize``.
    """
    # ``json`` is not imported at module level; without this the first
    # ``json.load`` raised NameError.
    import json

    with open(
        hf_hub_download(repo_path, "transformer/quantization_map.json"), "r", encoding="utf-8"
    ) as f:
        quantization_map = json.load(f)
    # Hand ``from_config`` the parsed dict: passing a file path instead goes
    # through diffusers' deprecated "config-passed-as-path" branch.
    with open(hf_hub_download(repo_path, "transformer/config.json"), "r", encoding="utf-8") as f:
        transformer_config = json.load(f)
    # Build on the meta device so no real weight storage is allocated before
    # requantize() materializes the quantized tensors.
    with torch.device("meta"):
        transformer = diffusers.FluxTransformer2DModel.from_config(transformer_config).to(torch.bfloat16)
    state_dict = load_file(
        hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors")
    )
    requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
    return transformer


def load_quanto_text_encoder_2(repo_path):
    """Build the T5 encoder (``text_encoder_2``) from a pre-quantized optimum-quanto checkpoint.

    A ``T5Config`` is built from the repo's ``text_encoder_2/config.json``. The
    ``T5EncoderModel`` skeleton is created on the ``meta`` device and cast to
    bfloat16. ``requantize`` then loads ``text_encoder_2/model.safetensors``
    according to ``text_encoder_2/quantization_map.json`` and places the
    weights on CUDA.

    Args:
        repo_path: Hugging Face Hub repo id holding the quantized checkpoint.

    Returns:
        The requantized T5 encoder, placed on ``cuda`` by ``requantize``.
    """
    # Neither ``json`` nor the bare ``transformers`` package is imported at
    # module level (only ``from transformers import ...``), so both names
    # raised NameError here.
    import json

    import transformers

    with open(
        hf_hub_download(repo_path, "text_encoder_2/quantization_map.json"), "r", encoding="utf-8"
    ) as f:
        quantization_map = json.load(f)
    with open(hf_hub_download(repo_path, "text_encoder_2/config.json"), "r", encoding="utf-8") as f:
        t5_config = transformers.T5Config(**json.load(f))
    # Meta-device skeleton: real storage is only created by requantize().
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
    return text_encoder_2
    
# Placeholder used only in annotations (e.g. ``-> Pipeline``). It is bound to
# None, so it carries no real type information.
Pipeline = None

# If torch.compile is used, let TorchDynamo fall back to eager execution
# instead of raising when it fails to compile a frame.
torch._dynamo.config.suppress_errors = True

def weight_svd_prune(module, threshold_ratio=0.5):
    """Replace ``module.weight`` in place with a truncated-SVD (low-rank) approximation.

    The weight is flattened to a 2-D matrix of shape ``(shape[0], -1)`` and
    decomposed with SVD. It is rebuilt from only the largest singular values,
    discarding the smallest ``threshold_ratio`` fraction. The result is
    reshaped back to the weight's original shape and stored with the
    original dtype and device.

    Args:
        module: Any object exposing a ``weight`` tensor/parameter
            (e.g. ``torch.nn.Linear`` or ``torch.nn.Conv2d``).
        threshold_ratio: Fraction of singular values to drop, in ``[0, 1]``.
            ``0`` keeps the weight (numerically) unchanged; ``1`` zeroes it.

    Raises:
        ValueError: If ``threshold_ratio`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= threshold_ratio <= 1.0:
        # Outside this range ``k`` goes negative and the truncation is garbage.
        raise ValueError(f"threshold_ratio must be in [0, 1], got {threshold_ratio!r}")

    weight = module.weight.data
    orig_shape = weight.shape
    # numpy has no bfloat16, so upcast to float32 before converting.
    w = weight.detach().cpu().float().numpy().reshape(orig_shape[0], -1)
    U, S, Vh = np.linalg.svd(w, full_matrices=False)

    k = int(len(S) * (1 - threshold_ratio))  # number of singular values kept
    # Rank-k reconstruction: scaling U's first k columns by S[:k] equals
    # U @ diag(S with the tail zeroed) @ Vh, without building the diagonal
    # matrix or multiplying through the discarded components.
    w_pruned = (U[:, :k] * S[:k]) @ Vh[:k]

    # Restore the weight's ORIGINAL shape. Reshaping to the flattened matrix's
    # shape would silently turn e.g. a 4-D conv kernel into a 2-D matrix.
    module.weight.data = torch.from_numpy(w_pruned.reshape(orig_shape)).to(
        device=weight.device, dtype=weight.dtype
    )
    
# Hugging Face Hub repo id of the base FLUX pipeline.
ckpt_id = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    Runs the Python garbage collector. When CUDA is available, it also releases
    unoccupied cached CUDA memory and resets the peak-memory counters, so later
    measurements start fresh. The CUDA calls are skipped on hosts without CUDA.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() also resets the max-allocated counter; the
        # former extra reset_max_memory_allocated() call is deprecated and
        # merely forwards to it.
        torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    """Build the FLUX text-to-image pipeline with qint8 quanto weights and warm it up.

    Steps:
    1. Load the base pipeline from ``ckpt_id`` in bfloat16, without its
       transformer and ``text_encoder_2``.
    2. Plug in the pre-quantized transformer and T5 encoder loaded from the
       quantized repo.
    3. Enable sequential CPU offload.
    4. Run two warm-up generations (1024x1024, 4 steps, guidance 0.0,
       max_sequence_length 256) before returning.

    Returns:
        The ready-to-use diffusers pipeline.
    """
    empty_cache()
    dtype = torch.bfloat16
    # NOTE(review): the base checkpoint is FLUX.1-schnell, but the quantized
    # transformer/T5 come from a FLUX.1-dev repo, while warm-up and inference
    # use schnell-style settings (guidance_scale=0.0, 4 steps). Confirm that
    # mixing the two is intentional.
    quantized_repo = "Disty0/FLUX.1-dev-qint8"

    pipeline = diffusers.AutoPipelineForText2Image.from_pretrained(
        ckpt_id, transformer=None, text_encoder_2=None, torch_dtype=dtype
    )
    pipeline.transformer = load_quanto_transformer(quantized_repo)
    pipeline.text_encoder_2 = load_quanto_text_encoder_2(quantized_repo)
    pipeline = pipeline.to(dtype=dtype)

    # Warm-up: run the full generation path twice before serving requests.
    pipeline.enable_sequential_cpu_offload()
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


from datetime import datetime
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for ``request`` with a seeded CUDA generator.

    Clears memory first, then runs ``pipeline`` with fixed settings: guidance
    0.0, 4 inference steps and max_sequence_length 256. The prompt, seed,
    height and width come from the request. Returns the first image of the
    pipeline output.
    """
    empty_cache()
    rng = Generator("cuda").manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    first_image = result.images[0]
    return first_image