# Source: flux-super-compile-6 / src/pipeline.py
# Author: jokerbit ("Upload folder using huggingface_hub", commit 21c5203, verified)
import os
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config
import torch
import gc
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight
from time import perf_counter
# --- Runtime configuration, executed once at import time --------------------

# Raises KeyError if HOME is unset. Not referenced elsewhere in this file;
# presumably kept for other modules that import it.
HOME = os.environ["HOME"]

# CUDA caching-allocator options. Assumed to need setting before CUDA is first
# initialised in this process (e.g. by the memory-fraction call further down)
# to take effect. TODO confirm nothing imported above touches CUDA earlier.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"

# Model / sampling constants.
FLUX_CHECKPOINT = "jokerbit/flux.1-schnell-city96"  # repo id given to FluxPipeline.from_pretrained
DTYPE = torch.bfloat16  # torch_dtype the pipeline weights are loaded in
NUM_STEPS = 4  # num_inference_steps used for every request

# Process-wide torch backend switches.
torch.backends.cudnn.benchmark = True  # let cuDNN autotune its algorithms
torch.backends.cuda.matmul.allow_tf32 = True  # permit TF32 for float32 matmuls
torch.cuda.set_per_process_memory_fraction(0.99)  # cap this process at 99% of device memory
def empty_cache():
    """Free unreferenced Python objects, release cached CUDA memory and reset peak stats.

    ``gc.collect()`` runs first so tensors kept alive only by unreachable
    reference cycles are freed. ``torch.cuda.empty_cache()`` then releases the
    caching allocator's unoccupied blocks. Finally, every CUDA peak-memory
    counter is reset.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # ``torch.cuda.reset_max_memory_allocated`` is deprecated. It merely calls
    # ``reset_peak_memory_stats`` (and emits a FutureWarning), so this single
    # call already resets max_memory_allocated along with the other peaks.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> FluxPipeline:
    """Load the FLUX checkpoint, configure it for memory/speed and warm it up.

    Returns:
        A ``FluxPipeline`` loaded from ``FLUX_CHECKPOINT`` in ``DTYPE``, with
        sequential CPU offload enabled for every component except the VAE. It
        has already run two throw-away generations, presumably so that
        ``torch.compile`` / CUDA-graph warm-up cost is paid here rather than on
        the first real request.
    """
    empty_cache()
    pipe = FluxPipeline.from_pretrained(FLUX_CHECKPOINT, torch_dtype=DTYPE)
    _configure_pipeline(pipe)
    _warm_up(pipe)
    return pipe


def _configure_pipeline(pipe: FluxPipeline) -> None:
    """Apply memory format, VAE tiling, VAE compilation and CPU-offload settings in place."""
    for module in (pipe.text_encoder, pipe.text_encoder_2, pipe.transformer, pipe.vae):
        module.to(memory_format=torch.channels_last)
    pipe.vae.enable_tiling()
    # NOTE(review): torch.compile(module) compiles only the module's ``forward``.
    # If FluxPipeline decodes via ``vae.decode(...)`` (likely, not verified here),
    # decoding still runs eagerly. Confirm before relying on this for speed.
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead")
    # The VAE is excluded from sequential offload; all other components are offloaded.
    pipe._exclude_from_cpu_offload = ["vae"]
    pipe.enable_sequential_cpu_offload()


def _warm_up(pipe: FluxPipeline, runs: int = 2) -> None:
    """Run ``runs`` throw-away generations with the same step count, guidance and
    max_sequence_length as ``infer`` (at the pipeline's default resolution)."""
    prompt = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
    for _ in range(runs):
        empty_cache()
        # Use NUM_STEPS (not a literal) so warm-up always matches infer().
        pipe(prompt, guidance_scale=0., max_sequence_length=256, num_inference_steps=NUM_STEPS)
    empty_cache()
@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    """Generate one image for ``request`` with ``_pipeline`` (e.g. from ``load_pipeline``).

    When ``request.seed`` is set, a CUDA ``Generator`` seeded with it is passed
    to the pipeline; otherwise ``generator=None`` is passed. CUDA peak-memory
    statistics are reset before each generation.

    Returns:
        The first image of the pipeline output, requested as PIL.
    """
    seed = request.seed
    generator = None if seed is None else Generator(device="cuda").manual_seed(seed)
    torch.cuda.reset_peak_memory_stats()
    output = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    )
    return output.images[0]