import gc
import json
import os
from functools import wraps

import torch
import torch._dynamo
import transformers
from diffusers import DiffusionPipeline, FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from optimum.quanto import requantize
from PIL.Image import Image
from torch import Generator

from pipelines.models import TextToImageRequest

# Fall back to eager execution when torch.compile/dynamo hits an error, let the
# CUDA caching allocator grow segments instead of fragmenting, and keep the
# fast-tokenizers parallelism enabled.
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

MAIN_ID = "RichardWilliam/FullyFLUXSCH"
REV = "c5f4f70c6cb9228a9c258799aadc660dde417af6"

# Alias used in the annotations below; DiffusionPipeline.from_pretrained returns
# a FluxPipeline for this checkpoint.
Pipeline = FluxPipeline
apply_quanto = True  # set to False to keep the pipeline's stock T5 text encoder


def to_hell():
    """Release cached CUDA memory and reset the allocator's peak-memory stats."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of the call below;
    # keeping both is redundant but harmless.
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def error_handler(func):
    """Return a wrapper that logs any exception from ``func`` and yields None."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None

    return wrapper


@error_handler
def load_quanto_text_encoder_2(text_repo_path):
    """Rebuild a quantized T5 encoder: instantiate it weight-free on the meta
    device, then let optimum-quanto requantize it on the GPU."""
    # The quantization map is read from the current working directory.
    with open("quantization_map.json", "r") as f:
        quantization_map = json.load(f)

    with open(os.path.join(text_repo_path, "config.json"), "r") as f:
        t5_config = transformers.T5Config(**json.load(f))

    # Materialise only the module structure (no weight allocation) on "meta".
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)

    # requantize() needs the saved quantized weights; "model.safetensors" is an
    # assumed file name for the repo layout and may need adjusting.
    from safetensors.torch import load_file
    state_dict = load_file(os.path.join(text_repo_path, "model.safetensors"))

    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
    return text_encoder_2
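

# Hedged sketch of how the artifacts consumed by load_quanto_text_encoder_2
# (quantization_map.json plus a quantized state dict) could be produced with
# optimum-quanto. Illustrative only and never called in this module; the output
# file names are assumptions, not paths recorded in this repo.
def export_quantized_t5(text_encoder_2, out_dir):
    from safetensors.torch import save_file
    from optimum.quanto import freeze, qint8, quantization_map, quantize

    quantize(text_encoder_2, weights=qint8)  # swap weights for qint8 proxies
    freeze(text_encoder_2)  # materialise the quantized weights
    save_file(text_encoder_2.state_dict(), os.path.join(out_dir, "model.safetensors"))
    with open(os.path.join(out_dir, "quantization_map.json"), "w") as f:
        json.dump(quantization_map(text_encoder_2), f)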


def load_pipeline() -> Pipeline:
    # Load the transformer from a locally cached snapshot rather than the Hub.
    main_path = os.path.join(
        HF_HUB_CACHE,
        "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        main_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = DiffusionPipeline.from_pretrained(
        MAIN_ID,
        revision=REV,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    if apply_quanto:
        # text_repo_path must point at the local snapshot holding the quantized
        # T5 config and weights; that path is not recorded here, so the loader
        # fails gracefully and the stock bfloat16 encoder is kept.
        text_encoder_v2 = load_quanto_text_encoder_2(text_repo_path=None)
        if text_encoder_v2 is None:
            print("Quantized text encoder unavailable; keeping the default text_encoder_2")
        else:
            pipeline.text_encoder_2 = text_encoder_v2

    # Warm-up passes: trigger compilation/autotuning and allocator growth so the
    # first real request does not pay that cost.
    for _ in range(3):
        pipeline(
            prompt="I am the worst",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
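
# Hedged sketch: torch._dynamo.config.suppress_errors above suggests the
# transformer was meant to be compiled before the warm-up passes. A possible
# (assumed, not active) form of that step:
#
#     pipeline.transformer = torch.compile(
#         pipeline.transformer, mode="max-autotune", fullgraph=False
#     )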


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Clear cached memory and peak stats before each request.
    to_hell()

    # A fresh, seeded generator per request keeps outputs reproducible.
    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
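

# Hedged usage sketch: drives load_pipeline()/infer() end to end. The
# TextToImageRequest constructor fields (prompt, width, height, seed) are
# assumed from how infer() reads them above and may not match pipelines.models.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        width=1024,
        height=1024,
        seed=0,
    )
    infer(request, pipe).save("sample.png")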