# eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee import os import torch import torch._dynamo import gc torch._dynamo.config.suppress_errors = True os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True" os.environ["TOKENIZERS_PARALLELISM"] = "True" from huggingface_hub.constants import HF_HUB_CACHE from torch import Generator from diffusers import FluxTransformer2DModel, DiffusionPipeline from PIL.Image import Image from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny from pipelines.models import TextToImageRequest from optimum.quanto import requantize import json import transformers from functools import wraps torch._dynamo.config.suppress_errors = True os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True" os.environ["TOKENIZERS_PARALLELISM"] = "True" MAIN_ID = "RichardWilliam/FullyFLUXSCH" REV = "c5f4f70c6cb9228a9c258799aadc660dde417af6" Pipeline = None apply_quanto=1 def to_hell(): gc.collect() torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() def error_handler(func): @wraps(func) def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: print(f"Error in {func.__name__}: {str(e)}") return None return wrapper @error_handler def load_quanto_text_encoder_2(text_repo_path): with open("quantization_map.json", "r") as f: quantization_map = json.load(f) with open(os.path.join(text_repo_path, "config.json"), "r") as f: t5_config = transformers.T5Config(**json.load(f)) with torch.device("meta"): text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16) state_dict = None requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda")) return text_encoder_2 def load_pipeline() -> Pipeline: main_path = os.path.join(HF_HUB_CACHE, "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f") origin_trans = FluxTransformer2DModel.from_pretrained(main_path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last) transformer = origin_trans pipeline = DiffusionPipeline.from_pretrained(MAIN_ID, revision=REV, transformer=transformer, torch_dtype=torch.bfloat16) pipeline.to("cuda") text_encoder_v2 = load_quanto_text_encoder_2(text_repo_path=None) if text_encoder_v2==None: print("Something wrong") else: pipeline.text_encoder_2 = text_encoder_v2 for __ in range(3): pipeline(prompt="I am the worst", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256) return pipeline @torch.no_grad() def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image: to_hell() generator = Generator(pipeline.device).manual_seed(request.seed) return pipeline( request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, ).images[0]