# flux-quant-4-cuda / src / pipeline.py
# (Hugging Face repo page artifact: uploaded by jokerbit, commit message
# "Even less", revision ccd236e, verified.)
import gc
import os
from time import perf_counter
from typing import Optional

import torch
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from PIL.Image import Image
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config

from pipelines.models import TextToImageRequest
HOME = os.environ["HOME"]
# Pinned local snapshot of black-forest-labs/FLUX.1-schnell inside the HF hub
# cache; every loader below passes local_files_only=True, so this exact
# revision must already be downloaded.
FLUX_CHECKPOINT = os.path.join(HOME,
".cache/huggingface/hub/models--black-forest-labs--FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/")
# Component names that get torchao quantization in load_pipeline().
QUANTIZED_MODEL = ["transformer"]
# Shared int8 weight-only quantization config (torchao).
QUANT_CONFIG = int8_weight_only()
# Compute dtype used for every pipeline component.
DTYPE = torch.bfloat16
# FLUX.1-schnell is step-distilled; 4 denoising steps is the intended setting.
NUM_STEPS = 4
def get_transformer(quantize: bool = True, quant_config=None, quant_ckpt: Optional[str] = None):
    """Load the Flux transformer, optionally quantized.

    Args:
        quantize: Apply torchao weight-only quantization after loading.
            Ignored when ``quant_ckpt`` is given.
        quant_config: torchao quantization config; when ``None`` a fresh
            ``int8_weight_only()`` is built per call (the old code created a
            single shared instance in the default argument at import time).
        quant_ckpt: Path to a ``torch.save``'d state dict of an
            already-quantized model; when set, the model is rebuilt from its
            config and the state dict is loaded directly.

    Returns:
        FluxTransformer2DModel in ``DTYPE``.
    """
    if quant_ckpt is not None:
        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer", local_files_only=True)
        model = FluxTransformer2DModel.from_config(config).to(DTYPE)
        # NOTE(review): torch.load unpickles arbitrary objects (torchao tensor
        # subclasses require it) — only load checkpoints from trusted sources.
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        # assign=True keeps the checkpoint's (quantized) tensor objects instead
        # of copying into the freshly initialized parameters.
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    model = FluxTransformer2DModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE, local_files_only=True
    )
    if quantize:
        quantize_(model, int8_weight_only() if quant_config is None else quant_config)
    return model
def get_text_encoder(quantize: bool = True, quant_config=None, quant_ckpt: Optional[str] = None):
    """Load the CLIP text encoder, optionally quantized.

    Args:
        quantize: Apply torchao weight-only quantization after loading.
            Ignored when ``quant_ckpt`` is given.
        quant_config: torchao quantization config; when ``None`` a fresh
            ``int8_weight_only()`` is built per call (avoids the shared
            default-argument instance the old signature created at import).
        quant_ckpt: Path to a ``torch.save``'d state dict of an
            already-quantized encoder.

    Returns:
        CLIPTextModel in ``DTYPE``.
    """
    if quant_ckpt is not None:
        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder", local_files_only=True)
        model = CLIPTextModel(config).to(DTYPE)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    model = CLIPTextModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE, local_files_only=True
    )
    if quantize:
        quantize_(model, int8_weight_only() if quant_config is None else quant_config)
    return model
def get_text_encoder_2(quantize: bool = True, quant_config=None, quant_ckpt: Optional[str] = None):
    """Load the T5 text encoder, optionally quantized.

    Args:
        quantize: Apply torchao weight-only quantization after loading.
            Ignored when ``quant_ckpt`` is given.
        quant_config: torchao quantization config; when ``None`` a fresh
            ``int8_weight_only()`` is built per call (avoids the shared
            default-argument instance the old signature created at import).
        quant_ckpt: Path to a ``torch.save``'d state dict of an
            already-quantized encoder.

    Returns:
        T5EncoderModel in ``DTYPE``.
    """
    if quant_ckpt is not None:
        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2", local_files_only=True)
        model = T5EncoderModel(config).to(DTYPE)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        # Print after the load succeeds, matching the sibling loaders (the old
        # code printed "Loaded" before load_state_dict had run).
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    model = T5EncoderModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE, local_files_only=True
    )
    if quantize:
        quantize_(model, int8_weight_only() if quant_config is None else quant_config)
    return model
def get_vae(quantize: bool = True, quant_config=None, quant_ckpt: Optional[str] = None):
    """Load the VAE, optionally quantized.

    Args:
        quantize: Apply torchao weight-only quantization after loading.
            Ignored when ``quant_ckpt`` is given.
        quant_config: torchao quantization config; when ``None`` a fresh
            ``int8_weight_only()`` is built per call (avoids the shared
            default-argument instance the old signature created at import).
        quant_ckpt: Path to a ``torch.save``'d state dict of an
            already-quantized VAE.

    Returns:
        AutoencoderKL in ``DTYPE``.
    """
    if quant_ckpt is not None:
        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae", local_files_only=True)
        model = AutoencoderKL.from_config(config).to(DTYPE)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    model = AutoencoderKL.from_pretrained(
        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE, local_files_only=True
    )
    if quantize:
        quantize_(model, int8_weight_only() if quant_config is None else quant_config)
    return model
def empty_cache():
    """Collect Python garbage and reset CUDA memory bookkeeping.

    Runs the GC, then — only when CUDA is actually available, so the function
    is safe on CPU-only hosts — releases the CUDA caching allocator's unused
    blocks and resets the peak-memory statistics.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(); one call covers both.
        torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> FluxPipeline:
    """Assemble the FLUX.1-schnell pipeline from the pinned local checkpoint.

    Each component is quantized iff its name appears in QUANTIZED_MODEL
    (currently just the transformer). Ends with one throwaway generation to
    warm the pipeline up before real requests arrive.
    """
    empty_cache()
    # The membership test turns QUANTIZED_MODEL into per-component
    # quantize flags; all components share QUANT_CONFIG.
    transformer = get_transformer('transformer' in QUANTIZED_MODEL, QUANT_CONFIG)
    text_encoder = get_text_encoder("text_encoder" in QUANTIZED_MODEL, QUANT_CONFIG)
    text_encoder_2 = get_text_encoder_2("text_encoder_2" in QUANTIZED_MODEL, QUANT_CONFIG)
    vae = get_vae("vae" in QUANTIZED_MODEL, QUANT_CONFIG)
    pipe = FluxPipeline.from_pretrained(FLUX_CHECKPOINT,
                                        transformer=transformer,
                                        vae=vae,
                                        text_encoder=text_encoder,
                                        text_encoder_2=text_encoder_2,
                                        torch_dtype=torch.bfloat16).to("cuda")
    # pipe.transformer = torch.compile(pipe.transformer, backend="cudagraphs")
    # NOTE(review): enable_model_cpu_offload() after .to("cuda") looks
    # contradictory — offload manages device placement itself and normally
    # expects the pipeline on CPU; confirm this combination is intentional.
    pipe.enable_model_cpu_offload()
    # empty_cache()
    # Warm-up run: schnell needs no guidance (scale 0) and a 256-token prompt cap.
    pipe("cat", guidance_scale=0., max_sequence_length=256, num_inference_steps=4)
    return pipe
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    """Generate one PIL image for *request* using the prepared pipeline.

    A request without a seed samples nondeterministically; otherwise a
    CUDA generator is seeded for reproducible output.
    """
    generator = (
        None
        if request.seed is None
        else Generator(device="cuda").manual_seed(request.seed)
    )
    # empty_cache()
    result = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    )
    return result.images[0]
if __name__ == "__main__":
    # Smoke test: time pipeline construction, then render one fixed-seed sample.
    start_time = perf_counter()
    pipe = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")
    result = pipe(
        "cat holding a womboai sign",
        num_inference_steps=4,
        guidance_scale=0,
        generator=torch.Generator(pipe.device).manual_seed(666),
    )
    result.images[0].save("sample.png")