"""Load and run the FLUX.1-schnell text-to-image pipeline with custom weights."""

import gc
import os

import torch
import torch._dynamo
from diffusers import AutoencoderKL, AutoencoderTiny, FluxPipeline
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from torchao.quantization import fpx_weight_only, int8_weight_only, quantize_
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from pipelines.models import TextToImageRequest

# Environment configuration
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True

# Constants
PIPELINE_MODEL_ID = "black-forest-labs/FLUX.1-schnell"
PIPELINE_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
TEXT_MODEL_ID = "Chucklee/extra1_ste1"
TEXT_MODEL_REVISION = "b0c1ffee1c1bdb3d30df17835615d809b7b8d075"
EXTRA_MODEL_ID = "Chucklee/extra2_ste2"
EXTRA_MODEL_REVISION = "3bfa327be3b38ee6f9c3ca7a5bfea6beeaa9306c"
# Transformer weights are read directly from this HF-cache repo directory.
TRANSFORMER_REPO_DIR = "models--Chucklee--extra0_ste0"
TRANSFORMER_SNAPSHOT = "ed7260988c4cc0b3bcab5d1318997fd6fa99345b"
DEFAULT_PROMPT = "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper"

# Sampling settings shared by the warm-up passes and real requests so the two
# code paths cannot drift apart.
NUM_INFERENCE_STEPS = 4
GUIDANCE_SCALE = 0.0
MAX_SEQUENCE_LENGTH = 256
# Warm-up configuration used by load_pipeline().
WARMUP_RUNS = 2
WARMUP_SIZE = 1024


def load_pipeline() -> DiffusionPipeline:
    """Build the FLUX pipeline on CUDA with custom weights and warm it up.

    ``text_encoder_2`` is replaced by the T5 encoder from ``EXTRA_MODEL_ID``.
    The transformer is replaced by the checkpoint in the local HF cache
    snapshot ``TRANSFORMER_SNAPSHOT``. All other components come from
    ``PIPELINE_MODEL_ID`` at ``PIPELINE_REVISION``, in bfloat16.

    Returns:
        The pipeline, moved to ``"cuda"``, after ``WARMUP_RUNS`` warm-up
        generations at ``WARMUP_SIZE`` x ``WARMUP_SIZE``.
    """
    # NOTE: the pipeline uses the bf16 VAE shipped with PIPELINE_MODEL_ID.
    # To use a quantized VAE instead, load AutoencoderKL here, apply
    # quantize_(vae, int8_weight_only()) and pass ``vae=vae`` to
    # from_pretrained below. Decoded images would then differ.

    # NOTE(review): Module.to(memory_format=...) only affects 4-D parameters
    # and buffers; for these Linear-based models this is likely a no-op.
    text_encoder = T5EncoderModel.from_pretrained(
        EXTRA_MODEL_ID,
        revision=EXTRA_MODEL_REVISION,
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    transformer_path = os.path.join(
        HF_HUB_CACHE, TRANSFORMER_REPO_DIR, "snapshots", TRANSFORMER_SNAPSHOT
    )
    transformer_model = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    diffusion_pipeline = DiffusionPipeline.from_pretrained(
        PIPELINE_MODEL_ID,
        revision=PIPELINE_REVISION,
        transformer=transformer_model,
        text_encoder_2=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    diffusion_pipeline.to("cuda")

    # Warm-up generations, presumably so one-time CUDA/kernel setup costs are
    # paid here rather than on the first real request.
    for _ in range(WARMUP_RUNS):
        diffusion_pipeline(
            prompt=DEFAULT_PROMPT,
            width=WARMUP_SIZE,
            height=WARMUP_SIZE,
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
        )
    return diffusion_pipeline


@torch.no_grad()
def generate_image(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """Generate a single image for ``request`` using ``pipeline``.

    Args:
        request: Supplies ``prompt``, ``seed``, ``width`` and ``height``.
            An empty or missing prompt falls back to ``DEFAULT_PROMPT``.
        pipeline: A pipeline as returned by :func:`load_pipeline`.

    Returns:
        The first image of the pipeline output.
    """
    # Generator.manual_seed(None) raises TypeError. Without a seed, pass
    # generator=None so the pipeline samples from the global RNG instead.
    # NOTE(review): assumes TextToImageRequest.seed may be None — confirm.
    if request.seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(request.seed)

    prompt = request.prompt or DEFAULT_PROMPT

    return pipeline(
        prompt=prompt,
        generator=generator,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        height=request.height,
        width=request.width,
    ).images[0]