flux-quant-13-cuda / src /pipeline.py
jokerbit's picture
Upload folder using huggingface_hub
66c03f6 verified
import os
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config
import torch
import gc
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from time import perf_counter
# Home directory used to locate the local Hugging Face cache.
# Raises KeyError at import time if HOME is not set.
HOME = os.environ["HOME"]
# Names of the pipeline components that `quantize` will quantize.
QUANTIZED_MODEL = ["vae"]
# Options for PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after `import torch` but before the first
# torch.cuda call in this file -- presumably early enough to take effect; confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
# Hub repository id that `load_pipeline` loads the full pipeline from.
FLUX_CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
# Pinned local snapshot directory of that repository; `load_pipeline` reads
# the transformer weights from its "transformer" sub-folder.
FLUX_CACHE = os.path.join(HOME, ".cache/huggingface/hub/models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b")
# Global backend switches: cuDNN benchmark mode on, TF32 allowed for CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Cap the memory fraction this process may use on the CUDA device at 0.99.
torch.cuda.set_per_process_memory_fraction(0.99)
# torchao quantization config handed to `quantize` by `load_pipeline`.
QUANT_CONFIG = int8_weight_only()
# dtype passed as `torch_dtype` when the transformer and pipeline are loaded.
DTYPE = torch.bfloat16
# Number of inference steps for the pipeline.
NUM_STEPS = 4
# Prompt used for the warm-up runs in `load_pipeline` and by the `__main__` benchmark.
PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
def empty_cache():
    """Free unused memory and reset the CUDA peak-memory statistics.

    The Python garbage collector runs first so that objects kept alive only
    by reference cycles are released before the CUDA caching allocator is
    asked to give back its unused blocks.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # `torch.cuda.reset_max_memory_allocated()` is deprecated: it emits a
    # FutureWarning and simply forwards to `reset_peak_memory_stats()`, so a
    # single call is enough to reset every peak counter.
    torch.cuda.reset_peak_memory_stats()
def quantize(pipe, config):
    """Quantize, in place, every pipeline component named in QUANTIZED_MODEL.

    Args:
        pipe: Pipeline object exposing ``text_encoder``, ``text_encoder_2``,
            ``transformer`` and ``vae`` attributes (only the selected ones
            are accessed).
        config: Quantization config forwarded to ``quantize_``.

    Returns:
        The same ``pipe`` object that was passed in.
    """
    # (attribute name, extra keyword arguments for quantize_), in call order.
    # Only the transformer is quantized with an explicit target device.
    plan = (
        ("text_encoder", {}),
        ("text_encoder_2", {}),
        ("transformer", {"device": "cuda"}),
        ("vae", {}),
    )
    for component, extra in plan:
        if component not in QUANTIZED_MODEL:
            continue
        quantize_(getattr(pipe, component), config, **extra)
    return pipe
def load_pipeline() -> FluxPipeline:
    """Load the FLUX pipeline, quantize it, compile the VAE and warm it up.

    The transformer is loaded separately from the pinned local snapshot
    (``FLUX_CACHE``) and injected into the pipeline loaded from
    ``FLUX_CHECKPOINT``. Four 1024x1024 inferences are then executed before
    the pipeline is returned -- presumably to trigger compilation and warm
    caches ahead of the first real request.

    Returns:
        The ready-to-use pipeline.
    """
    empty_cache()

    transformer_dir = os.path.join(FLUX_CACHE, "transformer")
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        use_safetensors=False,
        torch_dtype=DTYPE,
    )
    pipeline = FluxPipeline.from_pretrained(
        FLUX_CHECKPOINT,
        transformer=flux_transformer,
        torch_dtype=DTYPE,
    )

    # Tiled + sliced VAE decoding is enabled before quantization/compilation.
    pipeline.vae.enable_tiling()
    pipeline.vae.enable_slicing()

    quantize(pipeline, QUANT_CONFIG)
    pipeline.vae = torch.compile(pipeline.vae)

    warmup_runs = 4
    for _ in range(warmup_runs):
        warmup_request = TextToImageRequest(prompt=PROMPT, height=1024, width=1024, seed=666)
        infer(warmup_request, pipeline)
    return pipeline
def encode_prompt(_pipeline, prompt: str):
    """Encode ``prompt`` using only the text-encoding parts of ``_pipeline``.

    A pipeline without transformer and VAE is derived from ``_pipeline`` via
    ``FluxPipeline.from_pipe`` and moved to CUDA; the prompt is then encoded
    under ``torch.inference_mode()`` with a maximum sequence length of 256.

    Args:
        _pipeline: Source pipeline providing the text encoders and tokenizers.
        prompt: Text prompt to encode.

    Returns:
        The result of ``FluxPipeline.encode_prompt``; ``infer`` unpacks it as
        ``(prompt_embeds, pooled_prompt_embeds, text_ids)``.
    """
    text_only = FluxPipeline.from_pipe(_pipeline, transformer=None, vae=None)
    text_only = text_only.to("cuda")
    with torch.inference_mode():
        return text_only.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )
def infer_latents(_pipeline, prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
    """Run the denoising stage and return the resulting latents.

    A pipeline without text encoders, tokenizers and VAE is derived from
    ``_pipeline`` via ``FluxPipeline.from_pipe`` and moved to CUDA, then
    called with pre-computed prompt embeddings and ``output_type="latent"``.

    Args:
        _pipeline: Source pipeline providing the transformer and scheduler.
        prompt_embeds: Prompt embeddings produced by ``encode_prompt``.
        pooled_prompt_embeds: Pooled prompt embeddings from ``encode_prompt``.
        width: Requested image width, or ``None`` for the pipeline default.
        height: Requested image height, or ``None`` for the pipeline default.
        seed: Seed for the random generator; ``None`` leaves it unseeded.

    Returns:
        The ``images`` field of the pipeline output (latents, since
        ``output_type="latent"`` is requested).
    """
    pipeline = FluxPipeline.from_pipe(
        _pipeline,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
    ).to("cuda")

    # `is None` (not truthiness) so that a seed of 0 is still honoured.
    generator = None if seed is None else Generator(pipeline.device).manual_seed(seed)

    return pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        # Use the module-level constant instead of a duplicated literal so the
        # step count is configured in exactly one place.
        num_inference_steps=NUM_STEPS,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images
def decode_latents(vae, latents, width, height):
    """Decode packed latents into a PIL image with ``vae``.

    The VAE is moved to CUDA, the latents are unpacked with
    ``FluxPipeline._unpack_latents``, rescaled with the VAE's
    ``scaling_factor`` / ``shift_factor`` and decoded under
    ``torch.inference_mode()``.

    Args:
        vae: VAE model exposing ``config``, ``to`` and ``decode``.
        latents: Packed latents as returned by ``infer_latents``.
        width: Target width; a falsy value selects ``64 * vae_scale_factor``.
        height: Target height; a falsy value selects ``64 * vae_scale_factor``.

    Returns:
        The first image of the post-processed batch (``output_type="pil"``).
    """
    vae.to("cuda")

    cfg = vae.config
    # NOTE(review): the scale-factor formula and the `_unpack_latents`
    # signature must match the installed diffusers release -- confirm.
    scale = 2 ** (len(cfg.block_out_channels))
    fallback_side = 64 * scale
    if not width:
        width = fallback_side
    if not height:
        height = fallback_side

    unpacked = FluxPipeline._unpack_latents(latents, height, width, scale)
    unpacked = (unpacked / cfg.scaling_factor) + cfg.shift_factor

    postprocessor = VaeImageProcessor(vae_scale_factor=scale)
    with torch.inference_mode():
        decoded = vae.decode(unpacked, return_dict=False)[0]
    return postprocessor.postprocess(decoded, output_type="pil")[0]
# @torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    """Generate one image for ``request`` in three stages.

    The stages are prompt encoding, latent denoising and VAE decoding. After
    each stage the modules that were just used are moved back to the CPU
    before the next stage runs.

    Args:
        request: Request carrying ``prompt``, ``width``, ``height`` and ``seed``.
        _pipeline: Pipeline returned by ``load_pipeline``.

    Returns:
        The decoded image.
    """
    # Stage 1: text encoding.
    embeds, pooled_embeds, ids = encode_prompt(_pipeline, request.prompt)
    # Only `text_encoder` is sent back to the CPU; `text_encoder_2` is left
    # on whatever device it is currently on.
    _pipeline.text_encoder.to("cpu")

    # Stage 2: denoising; the embeddings are dropped as soon as it finishes.
    latents = infer_latents(
        _pipeline,
        embeds,
        pooled_embeds,
        request.width,
        request.height,
        request.seed,
    )
    del embeds, pooled_embeds, ids

    flux = _pipeline.transformer
    flux.single_transformer_blocks.to("cpu")
    flux.transformer_blocks.to("cpu")

    # Stage 3: VAE decoding, then move the VAE off the GPU again.
    picture = decode_latents(_pipeline.vae, latents, request.width, request.height)
    torch.cuda.reset_peak_memory_stats()
    _pipeline.vae.to("cpu")
    return picture
# def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
# if request.seed is None:
# generator = None
# else:
# generator = Generator(device="cuda").manual_seed(request.seed)
# torch.cuda.reset_peak_memory_stats()
# image = _pipeline(prompt=request.prompt,
# width=request.width,
# height=request.height,
# guidance_scale=0.0,
# generator=generator,
# output_type="pil",
# max_sequence_length=256,
# num_inference_steps=NUM_STEPS).images[0]
# return image
if __name__ == "__main__":
    # Small benchmark: time the pipeline load (which includes its warm-up
    # runs) and then four inferences at the default resolution.
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)

    start_time = perf_counter()
    loaded_pipeline = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")

    benchmark_runs = 4
    for _run in range(benchmark_runs):
        start_time = perf_counter()
        infer(request, loaded_pipeline)
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")
# pipe("cat holding a womboai sign", num_inference_steps=4, guidance_scale=0, generator=torch.Generator(pipe.device).manual_seed(666)).images[0].save("sample.png")