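"""FLUX.1-schnell text-to-image pipeline with int8 weight-only (torchao) quantization.

Loads a pre-quantized transformer snapshot, quantizes the submodules listed in
QUANTIZED_MODEL, and exposes `load_pipeline`/`infer` for serving requests, plus
split-stage helpers (`encode_prompt`, `infer_latents`, `decode_latents`).
"""
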
import gc
import os
from time import perf_counter

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from PIL.Image import Image
from torch import Generator
from torchao.quantization import int8_weight_only, quantize_

from pipelines.models import TextToImageRequest

HOME = os.environ["HOME"]

# Pipeline submodules that get int8 weight-only quantization applied at load time.
QUANTIZED_MODEL = ["text_encoder_2", "vae"]

# Allocator settings; read when the CUDA caching allocator is initialized.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"

FLUX_CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
# Local snapshot containing the pre-quantized transformer weights.
FLUX_CACHE = os.path.join(
    HOME,
    ".cache/huggingface/hub/models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b",
)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.99)

QUANT_CONFIG = int8_weight_only()
DTYPE = torch.bfloat16
NUM_STEPS = 4
# Prompt used for the warmup request in load_pipeline() and for the benchmark below.
PROMPT = "martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle"


def empty_cache():
    """Release cached CUDA memory and reset peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def quantize(pipe, config):
    """Apply weight-only quantization in place to the submodules listed in QUANTIZED_MODEL."""
    if "text_encoder" in QUANTIZED_MODEL:
        quantize_(pipe.text_encoder, config)
    if "text_encoder_2" in QUANTIZED_MODEL:
        quantize_(pipe.text_encoder_2, config)
    if "transformer" in QUANTIZED_MODEL:
        quantize_(pipe.transformer, config, device="cuda")
    if "vae" in QUANTIZED_MODEL:
        quantize_(pipe.vae, config)
    return pipe


def load_pipeline() -> FluxPipeline:
    """Load the FLUX.1-schnell pipeline with a pre-quantized transformer and run a warmup inference."""
    empty_cache()
    # The transformer weights in this snapshot are already int8 weight-only quantized.
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(FLUX_CACHE, "transformer"),
        use_safetensors=False,
        torch_dtype=DTYPE,
    )
    pipe = FluxPipeline.from_pretrained(
        FLUX_CHECKPOINT,
        transformer=transformer,
        torch_dtype=DTYPE,
    )
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    quantize(pipe, QUANT_CONFIG)
    pipe.to("cuda")

    # Warmup request so kernel selection/compilation happens before real requests arrive.
    request = TextToImageRequest(prompt=PROMPT, height=1024, width=1024, seed=666)
    infer(request, pipe)

    return pipe


def encode_prompt(_pipeline, prompt: str):
    """Run only the text encoders on a prompt and return the prompt embeddings."""
    # Build a text-encoder-only view of the pipeline (no transformer, no VAE).
    pipeline = FluxPipeline.from_pipe(
        _pipeline,
        transformer=None,
        vae=None,
    ).to("cuda")
    with torch.no_grad():
        outputs = pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )
    del pipeline
    empty_cache()
    return outputs


def infer_latents(_pipeline, prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
    """Run only the denoising transformer on precomputed prompt embeddings and return packed latents."""
    # Build a transformer-only view of the pipeline (no text encoders, no VAE).
    pipeline = FluxPipeline.from_pipe(
        _pipeline,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
    ).to("cuda")

    if seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(seed)

    outputs = pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=NUM_STEPS,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images
    del pipeline
    empty_cache()
    return outputs


def decode_latents(vae, latents, width, height):
    """Unpack and decode FLUX latents into a PIL image with the given VAE."""
    vae.to("cuda")
    # Scale factor covering both the VAE downsampling and FLUX's 2x2 latent packing.
    vae_scale_factor = 2 ** len(vae.config.block_out_channels)
    width = width or 64 * vae_scale_factor
    height = height or 64 * vae_scale_factor
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
    return image_processor.postprocess(image, output_type="pil")[0]


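# Sketch of how the split-stage helpers above compose (they are not called by infer() below).
# `pipe` is assumed to come from load_pipeline(); the prompt, size, and seed values are
# placeholders. FluxPipeline.encode_prompt returns (prompt_embeds, pooled_prompt_embeds, text_ids):
#
#     prompt_embeds, pooled_prompt_embeds, _ = encode_prompt(pipe, "an example prompt")
#     latents = infer_latents(pipe, prompt_embeds, pooled_prompt_embeds, width=1024, height=1024, seed=666)
#     image = decode_latents(pipe.vae, latents, width=1024, height=1024)

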
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    """Generate a PIL image for a single text-to-image request."""
    if request.seed is None:
        generator = None
    else:
        generator = Generator(device="cuda").manual_seed(request.seed)
    torch.cuda.reset_peak_memory_stats()
    image = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    ).images[0]
    return image


if __name__ == "__main__":
    # Simple benchmark: load the pipeline once, then time a few requests.
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)

    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")

    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_)
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")