coldkey2 / src /pipeline.py
Your Name
Initial commit
2cc59cd
raw
history blame
3.81 kB
# FLUX Optimization Pipeline
import os
import torch
import torch._dynamo
import gc
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
from optimum.quanto import requantize
import json
import transformers
# Module-level configuration. These statements run once, at import time.

# Set TorchDynamo's suppress_errors flag (per the PyTorch docs, compilation
# failures then fall back to eager execution instead of raising).
torch._dynamo.config.suppress_errors = True
# Allocator setting for PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after `import torch`; presumably it still takes
# effect because CUDA has not been initialised yet at this point - confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Environment flag read by the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Hub repository id and pinned revision passed to DiffusionPipeline.from_pretrained
# in load_pipeline().
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# Placeholder used only as the annotation for load_pipeline()'s return value and
# infer()'s `pipeline` parameter; its runtime value is None.
Pipeline = None
# Not referenced anywhere in the visible part of this file.
apply_quanto=1
def reset_cache():
    """Free unreferenced Python objects and cached CUDA memory, then reset peak stats.

    Runs a garbage-collection pass first so that tensors which are no longer
    referenced can be released before the CUDA cache is emptied. On a host
    without CUDA only the garbage collection runs.

    Returns:
        None.
    """
    gc.collect()
    # The CUDA bookkeeping calls below raise on CPU-only hosts.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # The deprecated torch.cuda.reset_max_memory_allocated() merely forwards to
    # reset_peak_memory_stats() (and emits a FutureWarning), so one call suffices.
    torch.cuda.reset_peak_memory_stats()
def load_quanto_text_encoder_2(text_repo_path):
    """Build a T5 encoder from a local snapshot and requantize it onto CUDA.

    Args:
        text_repo_path: Directory that contains the encoder's ``config.json``.

    Returns:
        The ``transformers.T5EncoderModel`` instance after ``requantize`` has
        been applied to it.

    Raises:
        FileNotFoundError: If ``quantization_map.json`` (looked up relative to
            the current working directory) or ``config.json`` is missing.
    """
    # The quantization map is opened with a relative path, i.e. from the CWD.
    with open("quantization_map.json", "r") as map_file:
        quantization_map = json.load(map_file)

    config_path = os.path.join(text_repo_path, "config.json")
    with open(config_path, "r") as config_file:
        config_kwargs = json.load(config_file)
        t5_config = transformers.T5Config(**config_kwargs)

    # Instantiate on the meta device so no real weight storage is allocated here.
    with torch.device("meta"):
        encoder = transformers.T5EncoderModel(t5_config)
        encoder = encoder.to(torch.bfloat16)

    # NOTE(review): no state dict is supplied (None is passed through), so
    # requantize has no weights to load - confirm whether this path is ever
    # expected to succeed or whether callers rely on it raising.
    weights = None
    requantize(encoder, weights, quantization_map, device=torch.device("cuda"))
    return encoder
def load_pipeline() -> Pipeline:
    """Assemble the FLUX.1-schnell pipeline on CUDA and run warm-up generations.

    The T5 text encoder is first loaded through ``load_quanto_text_encoder_2``
    from the local Hugging Face cache; if that attempt raises, the bf16 encoder
    is loaded with ``T5EncoderModel.from_pretrained`` instead. The VAE and the
    transformer are loaded from their pinned snapshots and injected into the
    pipeline built from ``CHECKPOINT`` at ``REVISION``.

    Returns:
        The pipeline object, moved to ``"cuda"``, after three warm-up calls
        whose outputs are discarded.
    """
    try:
        text_repo_path = os.path.join(
            HF_HUB_CACHE,
            "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761",
        )
        text_encoder_2 = load_quanto_text_encoder_2(text_repo_path=text_repo_path)
    except Exception:
        # Best-effort fallback to the unquantized bf16 encoder. Catching
        # Exception (not a bare except) keeps KeyboardInterrupt and SystemExit
        # propagating instead of silently triggering a model download.
        text_encoder_2 = T5EncoderModel.from_pretrained(
            "RichardWilliam/XULF_T5_bf16",
            revision="63a3d9ef7b586655600ac9bd4e4747d038237761",
            torch_dtype=torch.bfloat16,
        ).to(memory_format=torch.channels_last)

    origin_vae = AutoencoderTiny.from_pretrained(
        "RichardWilliam/XULF_Vae",
        revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
        torch_dtype=torch.bfloat16,
    )

    # The transformer is read from the local hub cache snapshot directory.
    main_path = os.path.join(
        HF_HUB_CACHE,
        "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f",
    )
    origin_trans = FluxTransformer2DModel.from_pretrained(
        main_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=origin_vae,
        transformer=origin_trans,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up: three full generations with the same settings infer() uses.
    for _ in range(3):
        pipeline(
            prompt="sweet, subordinative, gender, mormyre, arteriolosclerosis, positivism, Antiochianism, palmerite",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline exposing a ``device`` attribute.

    Returns:
        The first entry of the ``images`` attribute of the pipeline's output.
    """
    # Clear caches and peak-memory counters before each generation.
    reset_cache()

    # Seed a generator on the pipeline's device so the request's seed is honoured.
    seeded_rng = Generator(pipeline.device)
    seeded_rng.manual_seed(request.seed)

    output = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=seeded_rng,
    )
    return output.images[0]