# disburbedPanda2 / src/pipeline.py
from diffusers import AutoencoderTiny, DiffusionPipeline
from PIL.Image import Image
from torch import Generator
import torch
import torchvision
import gc

from pipelines.models import TextToImageRequest
from model import Encoder, Decoder

Pipeline = None  # placeholder alias used in the signatures below; the loader returns a DiffusionPipeline
DTYPE = torch.bfloat16
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache():
    """Release cached allocations and reset CUDA memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of this call.
    torch.cuda.reset_peak_memory_stats()
def _load_matching_weights(module, ckpt_path, prefix):
    """Load a checkpoint into module, dropping `prefix` from key names and
    skipping entries whose name or shape does not match.

    Note: str.removeprefix is used rather than str.strip; strip() removes a
    *set of characters* from both ends and would mangle key names.
    """
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    target = module.state_dict()
    filtered = {}
    for k, v in state_dict.items():
        key = k.removeprefix(prefix)
        if key in target and v.size() == target[key].size():
            filtered[key] = v
    print(f"{ckpt_path}: matched {len(filtered)} of {len(target)} keys")
    module.load_state_dict(filtered, strict=False)


def load_pipeline() -> Pipeline:
    empty_cache()

    # Tiny AutoEncoder for FLUX.1, with the custom encoder/decoder from
    # model.py swapped in (16 latent channels).
    vae = AutoencoderTiny.from_pretrained("madebyollin/taef1")
    vae.encoder = Encoder(16)
    vae.decoder = Decoder(16)
    _load_matching_weights(vae.encoder, "encoder.pth", "encoder.")
    _load_matching_weights(vae.decoder, "decoder.pth", "decoder.")
    vae.encoder.requires_grad_(False)
    vae.decoder.requires_grad_(False)
    vae.to(dtype=DTYPE)
    empty_cache()

    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, vae=vae, torch_dtype=DTYPE)

    # channels_last usually improves CUDA kernel selection for conv-heavy modules.
    pipeline.text_encoder.to(memory_format=torch.channels_last)
    pipeline.text_encoder_2.to(memory_format=torch.channels_last)
    pipeline.transformer.to(memory_format=torch.channels_last)
    pipeline.vae.to(memory_format=torch.channels_last)

    # Compile the VAE and keep it resident on the GPU; everything else is
    # offloaded to CPU between layer executions.
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Two warm-up passes trigger compilation/autotuning before serving.
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
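
# Optional extension (a sketch, not enabled here): torchao int8 weight-only
# quantization could shrink the transformer's memory footprint at some
# quality cost. Inside load_pipeline(), before the offload hooks:
#
#     from torchao.quantization import quantize_, int8_weight_only
#     quantize_(pipeline.transformer, int8_weight_only())
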
sample = True  # flush caches once, on the first request only


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    global sample
    if sample:
        empty_cache()
        sample = False

    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pt",
    ).images[0]
    # Rescale the decoded tensor (x * 2 - 1) to match the custom decoder's
    # output range, then convert the float tensor to a PIL image.
    return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))
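

# Minimal local smoke test -- a sketch, assuming TextToImageRequest accepts
# prompt/width/height/seed keyword arguments (only the fields infer() reads
# above are used here).
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        width=1024,
        height=1024,
        seed=0,
    )
    result = infer(req, pipe)
    result.save("sample.png")
    print(f"wrote sample.png ({result.size[0]}x{result.size[1]})")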