# supasubmitflux4 / src/pipeline.py
# Manoj Bhat
# Initial commit
# c9df743
from diffusers import FluxPipeline, AutoencoderTiny
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import DiffusionPipeline
# Hugging Face Hub repository id of the base checkpoint. load_pipeline reads
# the CLIP text encoder (``text_encoder`` subfolder) and every pipeline
# component it does not override from this repository.
ckpt_id: str = "black-forest-labs/FLUX.1-schnell"

# Placeholder referenced only by the ``-> Pipeline`` / ``pipeline: Pipeline``
# annotations in this module. It is bound to None, so those annotations carry
# no real type information.
Pipeline = None
def empty_cache():
    """Free unreachable Python objects, release cached CUDA memory, reset peaks.

    ``gc.collect()`` runs first so that objects which are only kept alive by
    reference cycles (and any CUDA tensors they hold) are freed before the
    CUDA caching allocator is asked to return its unused blocks.

    The CUDA calls are skipped when no CUDA device is available. On a
    CPU-only host the helper therefore only runs garbage collection instead
    of raising while trying to initialise CUDA.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets *all* peak counters, including the one
    # the deprecated reset_max_memory_allocated() targeted. That older call
    # merely forwards here and emits a FutureWarning, so one call covers both.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell diffusers pipeline and warm it up.

    Components, all loaded in bfloat16:
      * vae: ``AutoencoderTiny`` from ``RobertML/FLUX.1-schnell-vae_e3m2``.
      * text_encoder: ``CLIPTextModel`` from the ``text_encoder`` subfolder
        of ``ckpt_id``.
      * text_encoder_2: ``T5EncoderModel`` from
        ``city96/t5-v1_1-xxl-encoder-bf16``.
      * everything else: loaded from ``ckpt_id`` by
        ``DiffusionPipeline.from_pretrained``.

    Sequential CPU offload is enabled, which trades speed for a small
    accelerator-memory footprint.

    Returns:
        The warmed-up pipeline, ready to be passed to ``infer``.
    """
    empty_cache()
    dtype = torch.bfloat16

    vae = AutoencoderTiny.from_pretrained(
        "RobertML/FLUX.1-schnell-vae_e3m2", torch_dtype=dtype
    )
    text_encoder = CLIPTextModel.from_pretrained(
        ckpt_id, subfolder="text_encoder", torch_dtype=dtype
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    )
    empty_cache()

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    )
    pipeline.enable_sequential_cpu_offload()

    # Warm up with the generation settings infer() uses (4 steps, no guidance,
    # 256-token T5 prompt). Without max_sequence_length the pipeline default
    # is used, so the warm-up would exercise a different text-encoder shape
    # than real requests.
    for _ in range(2):
        gc.collect()
        pipeline(
            prompt="",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request``.

    Sampling uses 4 inference steps, ``guidance_scale=0.0`` and a 256-token
    maximum T5 sequence length.

    Args:
        request: The request whose ``prompt``, ``seed``, ``height`` and
            ``width`` attributes are read.
        pipeline: The pipeline returned by ``load_pipeline``.

    Returns:
        The first generated PIL image. If generation raises an ``Exception``,
        the error is logged and the image at ``./backup.png`` is returned
        instead. That path is relative to the current working directory.
    """
    import logging

    gc.collect()
    try:
        # NOTE(review): seed appears to be optional on the request. A None
        # seed previously made manual_seed() raise, so the backup image was
        # always served. Now generation simply runs unseeded in that case.
        generator = (
            None
            if request.seed is None
            else Generator("cuda").manual_seed(request.seed)
        )
        image = pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        ).images[0]
    except Exception:
        # Deliberate best-effort fallback: the request still gets an image,
        # but the failure is now recorded instead of silently hidden, and
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        logging.getLogger(__name__).exception(
            "Image generation failed; returning ./backup.png"
        )
        image = img.open("./backup.png")
    return image