from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
import gc
import json
import time

import diffusers
import numpy as np
import torch
import torch._dynamo
import torch.nn.utils.prune as prune
import transformers
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from optimum.quanto import requantize
from PIL import Image as img
from PIL.Image import Image
from safetensors.torch import load_file
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from tqdm import tqdm
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel

from pipelines.models import TextToImageRequest
def load_quanto_transformer(repo_path):
    """Load a quanto-quantized Flux transformer from a Hugging Face Hub repo.

    Downloads the quantization map, model config, and safetensors weights
    from *repo_path*, builds an empty model skeleton, and requantizes it
    directly onto the GPU.

    Args:
        repo_path: Hub repo id (e.g. "Disty0/FLUX.1-dev-qint8").

    Returns:
        The requantized ``FluxTransformer2DModel`` living on CUDA.
    """
    with open(hf_hub_download(repo_path, "transformer/quantization_map.json"), "r") as f:
        quantization_map = json.load(f)
    # Instantiate on the meta device: no real memory is allocated until
    # requantize() materializes the (quantized) weights on CUDA.
    with torch.device("meta"):
        transformer = diffusers.FluxTransformer2DModel.from_config(
            hf_hub_download(repo_path, "transformer/config.json")
        ).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
    requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
    return transformer
def load_quanto_text_encoder_2(repo_path):
    """Load a quanto-quantized T5 text encoder from a Hugging Face Hub repo.

    Downloads the quantization map, T5 config, and safetensors weights
    from *repo_path*, builds an empty encoder skeleton, and requantizes it
    directly onto the GPU.

    Args:
        repo_path: Hub repo id (e.g. "Disty0/FLUX.1-dev-qint8").

    Returns:
        The requantized ``T5EncoderModel`` living on CUDA.
    """
    with open(hf_hub_download(repo_path, "text_encoder_2/quantization_map.json"), "r") as f:
        quantization_map = json.load(f)
    with open(hf_hub_download(repo_path, "text_encoder_2/config.json")) as f:
        t5_config = transformers.T5Config(**json.load(f))
    # Meta-device instantiation avoids allocating full-precision weights;
    # requantize() fills in the real quantized tensors on CUDA.
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
    return text_encoder_2
# Let torch.compile fall back to eager execution instead of raising when a
# graph cannot be compiled.
torch._dynamo.config.suppress_errors = True
# Placeholder used purely as a type annotation below; the real pipeline object
# is whatever load_pipeline() returns.
Pipeline = None
def weight_svd_prune(module, threshold_ratio=0.5):
    """Low-rank "prune" ``module.weight`` in place via truncated SVD.

    Flattens the weight to 2-D, zeroes the smallest ``threshold_ratio``
    fraction of its singular values, and writes the low-rank reconstruction
    back with the original shape, dtype, and device.

    Args:
        module: any module with a ``.weight`` parameter (e.g. ``nn.Linear``).
        threshold_ratio: fraction of singular values to drop; 0.0 keeps all
            (near-exact reconstruction), 1.0 zeroes the weight entirely.
    """
    weight = module.weight.data
    # BUGFIX: remember the full shape before flattening. The original code
    # rebound `w` to the 2-D view, so the final reshape was a no-op and a
    # 4-D conv weight would have been silently replaced by a 2-D tensor.
    orig_shape = weight.shape
    # SVD runs on CPU in float32; bfloat16 tensors cannot convert to numpy.
    w2d = weight.cpu().float().numpy().reshape(orig_shape[0], -1)
    U, S, Vt = np.linalg.svd(w2d, full_matrices=False)
    k = int(len(S) * (1 - threshold_ratio))  # number of singular values kept
    S[k:] = 0.0  # zero the tail directly instead of building a mask array
    w_pruned = (U * S) @ Vt  # U @ diag(S) @ Vt without materializing diag
    module.weight.data = (
        torch.from_numpy(w_pruned)
        .reshape(orig_shape)
        .to(dtype=weight.dtype, device=weight.device)
    )
# Base checkpoint: FLUX.1-schnell (4-step distilled model) by Black Forest Labs.
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache() -> None:
    """Free cached GPU memory and reset CUDA peak-memory statistics."""
    # Drop Python-side references first so their CUDA blocks become freeable.
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of
    # reset_peak_memory_stats(), so a single call suffices. The unused
    # `start = time.time()` local was removed.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build and warm up the quantized FLUX text-to-image pipeline.

    Loads the base schnell pipeline without its transformer and T5 text
    encoder, then swaps in pre-quantized (qint8) replacements downloaded
    from the "Disty0/FLUX.1-dev-qint8" repo. Sequential CPU offload keeps
    VRAM usage low, and two throw-away generations warm up CUDA kernels and
    caches so the first real request does not pay that cost.

    Returns:
        The ready-to-use diffusers pipeline.

    Note: ~50 lines of commented-out experiments (quantization, pruning,
    torch.compile attempts) and the unused ``dtype``/``device`` locals were
    removed from the original body.
    """
    empty_cache()
    # Skip the two largest components here; quantized versions are attached below.
    pipeline = diffusers.AutoPipelineForText2Image.from_pretrained(
        ckpt_id,
        transformer=None,
        text_encoder_2=None,
        torch_dtype=torch.bfloat16,
    )
    pipeline.transformer = load_quanto_transformer("Disty0/FLUX.1-dev-qint8")
    pipeline.text_encoder_2 = load_quanto_text_encoder_2("Disty0/FLUX.1-dev-qint8")
    pipeline = pipeline.to(dtype=torch.bfloat16)
    # Stream weights CPU->GPU layer by layer to fit in limited VRAM.
    pipeline.enable_sequential_cpu_offload()
    # Warm-up: two dummy generations; the prompt content is irrelevant.
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
from datetime import datetime
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for *request* with the pre-loaded pipeline."""
    empty_cache()
    # Seed a CUDA RNG so generation is reproducible per request.
    rng = Generator("cuda").manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]