# simpleflux4 — src/pipeline.py (Hub revision cd02ff1, author: manbeast3b)
import gc
import json
import time

import diffusers
import numpy as np
import torch
import torch._dynamo
import torch.nn.utils.prune as prune
import transformers
from diffusers import (
    AutoencoderKL,
    AutoencoderTiny,
    DiffusionPipeline,
    FluxPipeline,
    FluxTransformer2DModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from optimum.quanto import requantize
from PIL import Image as img
from PIL.Image import Image
from safetensors.torch import load_file
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from tqdm import tqdm
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel

from pipelines.models import TextToImageRequest
def load_quanto_transformer(repo_path):
    """Load a quanto-quantized ``FluxTransformer2DModel`` from a Hub repo.

    The model skeleton is instantiated on the ``meta`` device (structure
    only, no weight allocation), then ``requantize`` materialises the
    quantized weights from the safetensors state dict directly onto CUDA.

    Args:
        repo_path: Hub repo id containing ``transformer/config.json``,
            ``transformer/quantization_map.json`` and
            ``transformer/diffusion_pytorch_model.safetensors``.

    Returns:
        The quantized transformer module, resident on CUDA.
    """
    with open(hf_hub_download(repo_path, "transformer/quantization_map.json"), "r") as f:
        quantization_map = json.load(f)
    # meta device: allocate structure only — real weights come from requantize.
    with torch.device("meta"):
        transformer = diffusers.FluxTransformer2DModel.from_config(
            hf_hub_download(repo_path, "transformer/config.json")
        ).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
    requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
    return transformer
def load_quanto_text_encoder_2(repo_path):
    """Load a quanto-quantized T5 text encoder from a Hub repo.

    Mirrors :func:`load_quanto_transformer`: build the ``T5EncoderModel``
    skeleton on the ``meta`` device from the repo's config, then let
    ``requantize`` materialise the quantized weights onto CUDA.

    Args:
        repo_path: Hub repo id containing ``text_encoder_2/config.json``,
            ``text_encoder_2/quantization_map.json`` and
            ``text_encoder_2/model.safetensors``.

    Returns:
        The quantized T5 encoder, resident on CUDA.
    """
    with open(hf_hub_download(repo_path, "text_encoder_2/quantization_map.json"), "r") as f:
        quantization_map = json.load(f)
    with open(hf_hub_download(repo_path, "text_encoder_2/config.json")) as f:
        t5_config = transformers.T5Config(**json.load(f))
    # meta device: allocate structure only — real weights come from requantize.
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
    state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
    return text_encoder_2
# Fall back to eager execution instead of raising when torch.compile/dynamo
# hits an unsupported construct (the compile experiments below are disabled,
# so this is a harmless safety net).
torch._dynamo.config.suppress_errors = True
# Placeholder used purely as the annotation in load_pipeline()/infer();
# the real pipeline object is whatever AutoPipelineForText2Image returns.
Pipeline = None
def weight_svd_prune(module, threshold_ratio=0.5):
    """Low-rank-approximate ``module.weight`` in place via truncated SVD.

    Keeps the top ``(1 - threshold_ratio)`` fraction of singular values,
    zeroes the rest, and replaces the weight with the reconstruction —
    same shape, dtype and device as before.

    Args:
        module: any module exposing a ``weight`` parameter (Linear, Conv, ...).
        threshold_ratio: fraction of singular values to discard
            (0.0 keeps everything, 1.0 zeroes the weight entirely).
    """
    original_shape = module.weight.data.shape
    # float32 copy first: numpy cannot consume bfloat16 tensors directly.
    w = module.weight.data.cpu().float().numpy()
    w2d = w.reshape(w.shape[0], -1)  # flatten to 2-D for SVD
    U, S, V = np.linalg.svd(w2d, full_matrices=False)
    k = int(len(S) * (1 - threshold_ratio))  # number of singular values kept
    S[k:] = 0.0  # truncate the spectrum
    w_pruned = (U * S) @ V  # (U * S) @ V == U @ diag(S) @ V
    # BUG FIX: restore the ORIGINAL weight shape. The old code reshaped back
    # to the flattened 2-D shape (it had overwritten `w` with the 2-D view),
    # silently corrupting any >2-D weight such as conv kernels.
    w_pruned = w_pruned.reshape(original_shape)
    module.weight.data = torch.tensor(
        w_pruned, dtype=module.weight.data.dtype
    ).to(module.weight.data.device)
# Base (non-quantized) checkpoint; the quantized transformer and
# text_encoder_2 are swapped in from "Disty0/FLUX.1-dev-qint8" in load_pipeline().
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache() -> None:
    """Free host and CUDA memory and reset CUDA peak-memory counters.

    Called between pipeline stages to keep peak GPU memory down. Safe to
    call on a CPU-only machine (the CUDA calls are skipped).
    """
    gc.collect()
    # Guard: torch.cuda.reset_peak_memory_stats raises on CPU-only builds.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(), so a single call suffices.
        torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build the FLUX text-to-image pipeline with quantized heavy parts.

    Loads every component of the base checkpoint except the transformer and
    the T5 text encoder, substitutes int8-quantized versions of those two
    from "Disty0/FLUX.1-dev-qint8", enables sequential CPU offload to bound
    GPU memory, and runs two warm-up generations so later calls hit
    already-initialised kernels.

    Returns:
        The ready-to-use diffusers pipeline.
    """
    empty_cache()
    # Skip the two largest components here — quantized replacements follow.
    pipeline = diffusers.AutoPipelineForText2Image.from_pretrained(
        ckpt_id, transformer=None, text_encoder_2=None, torch_dtype=torch.bfloat16
    )
    pipeline.transformer = load_quanto_transformer("Disty0/FLUX.1-dev-qint8")
    pipeline.text_encoder_2 = load_quanto_text_encoder_2("Disty0/FLUX.1-dev-qint8")
    pipeline = pipeline.to(dtype=torch.bfloat16)
    # Offload modules to CPU between uses so the whole model never has to
    # fit on the GPU at once.
    pipeline.enable_sequential_cpu_offload()
    # Warm up: two throwaway generations with a nonsense prompt to trigger
    # kernel/initialisation costs before the first real request.
    for _ in range(2):
        empty_cache()
        pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
    return pipeline
from datetime import datetime
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for *request* using the prepared pipeline.

    Seeds a CUDA generator from ``request.seed`` so identical requests
    reproduce identical images; sampling settings (4 steps, no guidance)
    match the warm-up configuration in load_pipeline().
    """
    empty_cache()
    rng = Generator("cuda").manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]