import os

import torch
import torch._dynamo
from diffusers import AutoencoderTiny, DiffusionPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from torch import Generator
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest

# Environment optimizations
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True

# NOTE: this is None, so the `pipeline_class` annotations below evaluate to
# None and carry no type information.
pipeline_class = None
model_checkpoint = "black-forest-labs/FLUX.1-schnell"
model_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"


class NormalizationQuantization:
    """Perturb a model's trainable parameters and its buffers in place.

    Each trainable parameter has Gaussian noise (scaled by ``noise_level``)
    added and is then rounded down with ``torch.floor``; each floating-point
    or complex buffer is shifted by a constant 0.01.
    """

    def __init__(self, model, noise_level=0.05):
        """Store the model to modify and the noise scale.

        Args:
            model: Object exposing ``parameters()`` and ``buffers()``.
            noise_level: Multiplier applied to the standard-normal noise.
        """
        self.model = model
        self.noise_level = noise_level

    def apply(self):
        """Modify ``self.model`` in place and return it.

        The noise is drawn from torch's global RNG, so the result differs
        between runs unless that RNG is seeded by the caller.

        Returns:
            The same model object that was passed to ``__init__``.
        """
        # One no_grad context covers both loops instead of re-entering it
        # for every tensor.
        with torch.no_grad():
            for param in self.model.parameters():
                if not param.requires_grad:
                    continue
                noise = torch.randn_like(param.data) * self.noise_level
                # NOTE(review): floor() leaves every element at an integer
                # value -- confirm this is the intended transformation.
                param.data = torch.floor(param.data + noise)

            for buffer in self.model.buffers():
                # A 0.01 offset cannot be represented in integer or boolean
                # dtypes, so such buffers (index tables, masks) are left
                # untouched rather than silently truncated or corrupted.
                if not (buffer.is_floating_point() or buffer.is_complex()):
                    continue
                buffer.add_(torch.full_like(buffer, 0.01))

        return self.model


def _quantize_or_fallback(model, label):
    """Run NormalizationQuantization on ``model``; log and return it on error.

    Args:
        model: The model to modify.
        label: Human-readable component name used in the failure message.

    Returns:
        ``model`` in either case.
    """
    try:
        return NormalizationQuantization(model, noise_level=0.03).apply()
    except Exception as e:
        # NOTE(review): apply() mutates in place, so after a failure the
        # returned model may already be partially modified.
        print(f"Failed to apply normalization quantization on {label}: {e}")
        return model


def load_diffusion_pipeline() -> pipeline_class:
    """Build the diffusion pipeline, move it to CUDA, and warm it up.

    Loads the VAE, the second text encoder, and the transformer, passes the
    text encoder and transformer through NormalizationQuantization, assembles
    them into a ``DiffusionPipeline`` for ``model_checkpoint``, and runs three
    warm-up generations.

    Returns:
        The pipeline object, resident on the ``"cuda"`` device.
    """
    vae_model = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2Jan12",
        revision="da7c5cf904a9dbba65a7282396befa49623cd9cd",
        torch_dtype=torch.bfloat16,
    )

    base_text_encoder = T5EncoderModel.from_pretrained(
        "TrendForge/extra1Jan11",
        revision="c76831ddf0852be22835f79dc5c1fbacb1ccda9e",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Apply normalization quantization to text encoder
    text_encoder = _quantize_or_fallback(base_text_encoder, "text encoder")

    # The transformer is read from a fixed snapshot directory inside the
    # local Hugging Face hub cache.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0Jan10/snapshots/d3ded25a77fdef06de4059d94b080a34da6e7a82",
    )
    base_transformer_model = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    # Apply normalization quantization to transformer
    transformer_model = _quantize_or_fallback(
        base_transformer_model, "transformer model"
    )

    diffusion_pipeline = DiffusionPipeline.from_pretrained(
        model_checkpoint,
        revision=model_revision,
        vae=vae_model,
        transformer=transformer_model,
        text_encoder_2=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    diffusion_pipeline.to("cuda")

    # Warm-up runs with the same step count, guidance scale and sequence
    # length that perform_inference uses; the outputs are discarded.
    for _ in range(3):
        diffusion_pipeline(
            prompt="freezable, catacorolla, gaiassa, unenkindled, grubs, solidiform",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return diffusion_pipeline


@torch.no_grad()
def perform_inference(request: TextToImageRequest, pipeline: pipeline_class) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline exposing a ``device`` attribute.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Seed a generator on the pipeline's device from the request's seed.
    generator = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]