# taofluxing4 / src/pipeline.py
# Page-header residue from the hosting site: avatar of manbeast3b;
# last commit "Update src/pipeline.py" (a7ab256, verified).
from diffusers import AutoencoderTiny
from transformers import T5EncoderModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
import os
# CUDA caching-allocator options, passed to PyTorch through the environment.
# NOTE(review): this runs after `import torch`; it presumably still takes
# effect because the allocator reads it lazily -- confirm for the pinned
# PyTorch version.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False,garbage_collection_threshold:0.01",
    }
)

# No concrete class is bound here; the name exists so it can be used as an
# annotation in the signatures of `load_pipeline` and `infer`.
Pipeline = None

# Hub repository the base pipeline is loaded from.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Dtype requested for every `from_pretrained` call in this module.
DTYPE = torch.bfloat16
def clear():
    """Free unreferenced Python objects and cached CUDA memory.

    Runs a full garbage-collection pass, then (only when a CUDA device is
    available) releases the cached blocks held by PyTorch's CUDA allocator
    and resets the peak-memory statistics.

    The deprecated ``torch.cuda.reset_max_memory_allocated()`` call was
    dropped: it only forwards to ``reset_peak_memory_stats()``, which is
    already called here, so it added nothing but a deprecation warning.

    Returns:
        None.
    """
    gc.collect()

    # The statistics reset initialises CUDA and raises on hosts without a
    # usable device; skip the CUDA part there instead of crashing.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build, tune and warm up the text-to-image pipeline.

    Steps, in order:
      1. Clear Python/CUDA memory via ``clear()``.
      2. Load a bf16 T5 encoder and the tiny autoencoder ``taef1``, then
         apply int8 weight-only quantization to the autoencoder.
      3. Load ``MODEL_ID`` with those two components substituted in.
      4. Set cuDNN/TF32/memory-fraction backend options, request
         ``channels_last`` memory format, wrap the VAE in ``torch.compile``
         and enable sequential CPU offload with the VAE listed in
         ``_exclude_from_cpu_offload``.
      5. Run two throwaway 1024x1024 generations as warm-up.

    Returns:
        The configured pipeline object.
    """
    clear()

    # Replacement components passed to the base pipeline.
    t5_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=DTYPE
    )
    tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE)
    quantize_(tiny_vae, int8_weight_only())

    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        vae=tiny_vae,
        text_encoder_2=t5_encoder,
        torch_dtype=DTYPE,
    )

    # Global backend switches.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.cuda.set_per_process_memory_fraction(0.99)

    for component in (pipeline.text_encoder, pipeline.transformer, pipeline.vae):
        component.to(memory_format=torch.channels_last)

    pipeline.vae = torch.compile(pipeline.vae)
    # Must be set before offloading is enabled.
    # NOTE(review): presumably this keeps the compiled VAE out of the
    # offload hooks -- confirm against the diffusers version in use.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warm-up: results are discarded. The settings match those used by
    # `infer`, except the resolution is fixed at 1024x1024.
    warmup_prompts = (
        "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
        "",
    )
    for warmup_prompt in warmup_prompts:
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
# Truthy until the first `infer` call has performed its one-time cleanup.
sample = True


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    On the first call only, ``clear()`` is run and the module-level
    ``sample`` flag is switched off. Every call then resets CUDA peak-memory
    statistics, seeds a CUDA generator with ``request.seed`` and runs the
    pipeline for 4 steps with guidance disabled.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The object returned by ``load_pipeline``.

    Returns:
        The first image of the pipeline output (requested as PIL).
    """
    global sample
    if sample:
        clear()
        sample = None

    torch.cuda.reset_peak_memory_stats()

    rng = Generator("cuda").manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]