"""Text-to-image inference pipeline: FLUX with quanto-quantized weights."""

import gc
import json
import time
from datetime import datetime

from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
import diffusers
import transformers
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
import torch.nn.utils.prune as prune
import numpy as np
from tqdm import tqdm
from optimum.quanto import requantize
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def load_quanto_transformer(repo_path):
    """Build a FLUX transformer from a quanto-quantized checkpoint on the Hub.

    The module skeleton is created under ``torch.device("meta")``, so no real
    weight storage is allocated. ``requantize`` then loads the quantized
    weights from the downloaded safetensors file onto CUDA.

    Args:
        repo_path: Hub repo id that contains ``transformer/quantization_map.json``,
            ``transformer/config.json`` and
            ``transformer/diffusion_pytorch_model.safetensors``.

    Returns:
        The requantized ``FluxTransformer2DModel``.
    """
    map_path = hf_hub_download(repo_path, "transformer/quantization_map.json")
    with open(map_path, "r", encoding="utf-8") as f:
        quantization_map = json.load(f)
    with torch.device("meta"):
        transformer = diffusers.FluxTransformer2DModel.from_config(
            hf_hub_download(repo_path, "transformer/config.json")
        ).to(torch.bfloat16)
    state_dict = load_file(
        hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors")
    )
    requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
    return transformer


def load_quanto_text_encoder_2(repo_path):
    """Build the T5 text encoder from a quanto-quantized checkpoint on the Hub.

    Same approach as the transformer loader: create the skeleton on the
    ``meta`` device, then ``requantize`` the saved weights onto CUDA.

    Args:
        repo_path: Hub repo id that contains ``text_encoder_2/quantization_map.json``,
            ``text_encoder_2/config.json`` and ``text_encoder_2/model.safetensors``.

    Returns:
        The requantized ``transformers.T5EncoderModel``.
    """
    map_path = hf_hub_download(repo_path, "text_encoder_2/quantization_map.json")
    with open(map_path, "r", encoding="utf-8") as f:
        quantization_map = json.load(f)
    config_path = hf_hub_download(repo_path, "text_encoder_2/config.json")
    with open(config_path, encoding="utf-8") as f:
        t5_config = transformers.T5Config(**json.load(f))
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
    return text_encoder_2


torch._dynamo.config.suppress_errors = True

Pipeline = None


def weight_svd_prune(module, threshold_ratio=0.5):
    """Replace ``module.weight`` with a truncated-SVD (low-rank) approximation.

    The weight is flattened to 2-D as ``(shape[0], -1)`` and decomposed with
    SVD. The top ``int(len(S) * (1 - threshold_ratio))`` singular values are
    kept and the rest are zeroed. The result is reshaped back to the
    weight's ORIGINAL shape and cast back to its original dtype/device.

    Args:
        module: Any module with a ``weight`` parameter.
        threshold_ratio: Fraction of singular values to drop, in ``[0, 1]``.

    Raises:
        ValueError: If ``threshold_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= threshold_ratio <= 1.0:
        raise ValueError(f"threshold_ratio must be in [0, 1], got {threshold_ratio}")

    weight = module.weight.data
    original_shape = tuple(weight.shape)
    # numpy has no bfloat16, so go through float32 before converting.
    w = weight.detach().cpu().float().numpy().reshape(original_shape[0], -1)

    U, S, Vh = np.linalg.svd(w, full_matrices=False)
    k = int(len(S) * (1 - threshold_ratio))  # number of singular values kept
    S_kept = np.zeros_like(S)
    S_kept[:k] = S[:k]

    # U * S_kept scales U's columns, i.e. U @ diag(S_kept), without building diag.
    w_pruned = (U * S_kept) @ Vh
    # Reshape to the original shape, not the flattened 2-D one.
    w_pruned = w_pruned.reshape(original_shape)
    module.weight.data = torch.from_numpy(w_pruned).to(dtype=weight.dtype, device=weight.device)


ckpt_id = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    """Run Python GC, release cached CUDA memory, and reset CUDA peak-memory stats."""
    gc.collect()
    torch.cuda.empty_cache()
    # The deprecated reset_max_memory_allocated() only forwards to this call,
    # so one call resets the same counters.
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    """Assemble the text-to-image pipeline and warm it up with two dummy generations.

    The base components come from ``ckpt_id``. The transformer and T5 text
    encoder are replaced with quanto-quantized versions from the Hub.
    Sequential CPU offload is enabled before the warm-up runs.

    Returns:
        The ready-to-use diffusers pipeline.
    """
    empty_cache()

    # Earlier experiments lived here as commented-out code and were removed:
    # torchao int8 quantization of the VAE/text encoder, L1/random/SVD weight
    # pruning, and torch.compile (inductor / Torch-TensorRT).

    pipeline = diffusers.AutoPipelineForText2Image.from_pretrained(
        ckpt_id,
        transformer=None,
        text_encoder_2=None,
        torch_dtype=torch.bfloat16,
    )
    # NOTE(review): the base pipeline is FLUX.1-schnell, and the sampling below
    # uses guidance_scale=0.0 and 4 steps, but these quantized weights come from
    # a FLUX.1-dev repo. Confirm this mix is intended.
    pipeline.transformer = load_quanto_transformer("Disty0/FLUX.1-dev-qint8")
    pipeline.text_encoder_2 = load_quanto_text_encoder_2("Disty0/FLUX.1-dev-qint8")
    pipeline = pipeline.to(dtype=torch.bfloat16)

    # Warm up.
    pipeline.enable_sequential_cpu_offload()
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for ``request``.

    The output is seeded from ``request.seed`` with a CUDA generator and uses
    the request's prompt, height and width.
    """
    empty_cache()
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image