# Source: newyearspice1 / src/pipeline.py
# Last commit by manbeast3b: d949d1c "Update src/pipeline.py" (verified)
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
import os
# Process-wide runtime configuration applied at import time.
# NOTE: PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is first
# initialised (first CUDA use), not at `import torch`, so setting it here still
# takes effect as long as no CUDA work has happened yet.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Let TorchDynamo fall back to eager execution instead of raising on errors.
torch._dynamo.config.suppress_errors = True
# Placeholder name used in function annotations; it is plain None at runtime.
Pipeline = None
# NOTE(review): The triple-quoted string below is NOT executed. It is a bare
# string expression that holds an earlier implementation of empty_cache /
# trace_and_save_vae_decoder / load_pipeline / infer (its infer takes an extra
# `generator` argument and decodes latents manually). The live implementation
# follows after the closing quotes. It is kept verbatim for reference; consider
# deleting it and relying on version-control history instead.
'''
ckpt_id = "RobertML/FLUX.1-schnell-qf8"
ckpt_revision = "f360ee74b68f38c0b8abd873d0d5800509ed62a2"
def empty_cache():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def trace_and_save_vae_decoder(vae, latents):
try:
traced_vae_decode = torch.jit.trace(vae.decode, (latents, True))
torch.jit.save(traced_vae_decode, traced_vae_decode_path)
return traced_vae_decode
except Exception as e:
# print(f"JIT tracing failed: {e}")
return vae.decode #Fall back to untraced decoder.
def load_pipeline() -> Pipeline:
empty_cache()
dtype, device = torch.bfloat16, "cuda"
text_encoder_2 = T5EncoderModel.from_pretrained(
"city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
).to(memory_format=torch.channels_last)
vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=dtype)
path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
pipeline = DiffusionPipeline.from_pretrained(
ckpt_id,
vae=vae,
revision=ckpt_revision,
transformer=model,
text_encoder_2=text_encoder_2,
torch_dtype=dtype,
).to(device)
#quantize_(pipeline.vae, int8_weight_only())
pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
for _ in range(3):
pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
empty_cache()
return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
def encode_prompt(prompt: str):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt,
prompt_2=None,
max_sequence_length=256,
)
return prompt_embeds, pooled_prompt_embeds, text_ids
def infer_latents(prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None, generator):
if generator is None:
generator = Generator(pipeline.device).manual_seed(seed)
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=4,
guidance_scale=0.0,
width=width if width else 1024,
height=height if height else 1024,
generator=generator,
output_type="latent",
).images
return latents
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(request.prompt)
latents = infer_latents(prompt_embeds, pooled_prompt_embeds, request.width, request.height, request.seed, generator)
def decode_latents_to_image(latents, height: int, width: int, vae):
width = width if width else 1024
height = height if height else 1024
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
"""
# Try to load the traced model; trace and save if not found
traced_vae_decode_path = os.path.join(HF_HUB_CACHE, "decoder_trace")
if os.path.exists(traced_vae_decode_path):
try:
traced_vae_decode = torch.jit.load(traced_vae_decode_path)
print("Loaded traced VAE decoder from file.")
except Exception as e:
print(f"Error loading traced VAE decoder: {e}. Retracing...")
traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
else:
traced_vae_decode = trace_and_save_vae_decoder(vae, latents)
"""
traced_vae_decode = vae.decode
with torch.no_grad():
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = traced_vae_decode(latents, return_dict=False)[0] # Use the traced function
decoded_image = image_processor.postprocess(image, output_type="pil")[0]
return decoded_image
# image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
decoded_image=decode_latents_to_image(latents, request.height, request.width, pipeline.vae)
return(decoded_image)
'''
# Placeholder name used in the annotations below; it is plain None at runtime.
Pipeline = None
# Hugging Face repo id and pinned commit that load_pipeline() loads from.
ids, Revision = (
    "slobers/Flux.1.Schnella",
    "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b",
)
def empty_cache():
    """Free unreferenced Python objects, release cached CUDA memory, reset peak stats.

    ``gc.collect()`` runs first so that objects awaiting collection (and any
    tensors they hold) are released before ``torch.cuda.empty_cache()`` hands
    unused cached blocks back to the driver.

    ``torch.cuda.reset_max_memory_allocated`` is deliberately not called. It is
    deprecated and merely emits a FutureWarning before delegating to
    ``reset_peak_memory_stats``, which already resets every peak statistic.

    The CUDA calls are skipped when no CUDA device is available. On a CPU-only
    host this helper therefore only runs the GC pass instead of raising.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Load the FLUX pipeline from the local HF cache, compile it and warm it up.

    The transformer is loaded on its own from the ``transformer`` subfolder of the
    cached snapshot of ``ids`` at ``Revision`` (non-safetensors weights, bfloat16).
    It is then passed to ``FluxPipeline.from_pretrained`` so the pipeline reuses it.
    ``local_files_only=True`` is set, so the snapshot must already be present in
    ``HF_HUB_CACHE``.

    Returns:
        The pipeline moved to ``"cuda"``, with a ``torch.compile``-d transformer
        that has already run three warm-up generations.
    """
    empty_cache()
    # Derive the snapshot directory from the same repo id / revision used for the
    # pipeline below instead of a second hard-coded copy, so the two cannot drift.
    # The hub cache stores "org/name" at "models--org--name/snapshots/<commit>".
    # This requires Revision to be a full commit hash, not a branch name.
    snapshot_dir = os.path.join(
        HF_HUB_CACHE, "models--" + ids.replace("/", "--"), "snapshots", Revision
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(snapshot_dir, "transformer"),
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)
    pipeline = FluxPipeline.from_pretrained(
        ids,
        revision=Revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
    )
    # torch.compile compiles lazily on the first call. Run a few full 1024x1024
    # generations here so compilation/autotuning cost is paid before real requests.
    for _ in range(3):
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    empty_cache()
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request`` with a loaded pipeline.

    Sampling uses fixed settings: 4 inference steps, guidance scale 0.0 and a
    256-token max sequence length. ``request.height`` and ``request.width`` are
    forwarded unchanged to the pipeline.

    When ``request.seed`` is set, a generator seeded with it on the pipeline's
    device makes the output reproducible. When it is ``None`` (assumes the
    request model permits that — TODO confirm), no generator is passed and the
    pipeline draws its own random noise. ``manual_seed`` does not accept ``None``.

    Returns:
        The first image of the pipeline's output.
    """
    generator = None
    if request.seed is not None:
        generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]