from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import BitsAndBytesConfig
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
import os
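# Tune the CUDA caching allocator: keep expandable segments disabled and reclaim cached blocks aggressively (GC threshold 1%).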
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
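# Placeholder for the pipeline type referenced in the annotations below; the real object is built in load_pipeline().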
Pipeline = None
# Define the quantization config
# nf4_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
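# 8-bit (bitsandbytes) quantization, applied below to the transformer and both text encoders.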
config = BitsAndBytesConfig(load_in_8bit=True)
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache():
    # Drop Python-side references, return cached CUDA blocks, and reset the allocator's peak stats.
    start = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.time() - start}")
cache_dir = "/root/.cache/huggingface/hub/models--manbeast3b--flux-schnell-int8/snapshots/eb656b7968de3088ccac7cda876f5782e5a2f721/"
def load_pipeline() -> Pipeline:
    empty_cache()
    dtype, device = torch.bfloat16, "cuda"
    # text_encoder_2 = T5EncoderModel.from_pretrained(
    #     "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
    # )
    # text_encoder_2 = T5EncoderModel.from_pretrained(
    #     "sayakpaul/flux.1-dev-nf4-pkg", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
    # )
    # text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
model_id = "manbeast3b/flux-schnell-int8"
transformer = FluxTransformer2DModel.from_pretrained(
cache_dir, subfolder="transformer", torch_dtype=torch.bfloat16, quantization_config=config
)
text_encoder_2 = T5EncoderModel.from_pretrained(
cache_dir, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, quantization_config=config
)
text_encoder = CLIPTextModel.from_pretrained(
cache_dir, subfolder="text_encoder",torch_dtype=torch.bfloat16, quantization_config=config
)
    # vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        # vae=vae,
        transformer=transformer,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.cuda.set_per_process_memory_fraction(0.95)
    # pipeline.text_encoder.to(memory_format=torch.channels_last)
    # pipeline.transformer.to(memory_format=torch.channels_last)
    # torch.jit.enable_onednn_fusion(True)
    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae)
    # Keep the VAE resident on the GPU while the rest of the pipeline is offloaded module by module.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()
    # Warmup passes to trigger torch.compile and prime CUDA caches before serving requests.
    for _ in range(2):
        pipeline(
            prompt="warmup run testing one two three",
            width=1024, height=1024,
            guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256,
        )
    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    torch.cuda.reset_peak_memory_stats()
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
    return image
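# Usage sketch (not part of the original module, shown only to illustrate the expected call flow):
# the serving harness is assumed to call load_pipeline() once at startup and infer() per request.
# TextToImageRequest's real constructor lives in pipelines.models; the keyword arguments below are
# an assumption that mirrors the fields infer() reads (prompt, seed, height, width).
if __name__ == "__main__":
    pipe = load_pipeline()
    sample_request = TextToImageRequest(
        prompt="a photo of an astronaut riding a horse on mars",
        seed=0,
        height=1024,
        width=1024,
    )
    result = infer(sample_request, pipe)
    result.save("output.png")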