import gc
import os
from typing import Any, Dict, Optional

import torch
from PIL.Image import Image
from torch import Generator

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE

from pipelines.models import TextToImageRequest

# Central configuration: model checkout, device, dtype and allocator settings.
MODEL_CONFIG = {
    "repository": "black-forest-labs/FLUX.1-schnell",
    "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",
    "compute_device": "cuda",
    "precision": torch.bfloat16,
    "memory_allocation": "expandable_segments:True",
}

# CUDA math/kernel knobs: TF32 matmuls and cuDNN autotuning trade a little
# numeric precision for throughput — acceptable for image generation.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read lazily by the caching
# allocator, so setting it here (after `import torch` but before any model
# load / first CUDA allocation) should still take effect — confirm on the
# deployed torch version.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = MODEL_CONFIG["memory_allocation"]


def reclaim_memory() -> None:
    """Release unused GPU memory and reset the CUDA allocator statistics."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() supersedes the deprecated
        # reset_max_memory_allocated(); a single call covers both counters.
        torch.cuda.reset_peak_memory_stats()


def acquire_text_encoder() -> T5EncoderModel:
    """Fetch and prepare the T5 text encoder component (text_encoder_2).

    Returns:
        The pretrained T5 encoder in the configured precision.
    """
    return T5EncoderModel.from_pretrained(
        pretrained_model_name_or_path="manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=MODEL_CONFIG["precision"],
    )


def acquire_transformer() -> FluxTransformer2DModel:
    """Fetch and prepare the FLUX transformer component.

    Loads from the local HF hub cache snapshot (so the revision must already
    be downloaded) and converts the weights to channels_last memory format.

    Returns:
        The transformer in the configured precision, channels_last layout.
    """
    cache_location = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/"
        "cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        cache_location,
        torch_dtype=MODEL_CONFIG["precision"],
        use_safetensors=False,
    )
    return transformer.to(memory_format=torch.channels_last)


def initialize_pipeline(
    components: Optional[Dict[str, Any]] = None,
) -> DiffusionPipeline:
    """Construct, configure and warm up the diffusion pipeline.

    Args:
        components: Optional pre-built pipeline components. Missing
            "text_encoder_2" / "transformer" entries are filled in from the
            acquire_* helpers; other keys are forwarded unchanged.

    Returns:
        A pipeline on the configured device, warmed up and ready to serve.
    """
    if components is None:
        components = {}
    if "text_encoder_2" not in components:
        components["text_encoder_2"] = acquire_text_encoder()
    if "transformer" not in components:
        components["transformer"] = acquire_transformer()

    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_CONFIG["repository"],
        revision=MODEL_CONFIG["revision"],
        torch_dtype=MODEL_CONFIG["precision"],
        **components,
    )

    pipeline.to(MODEL_CONFIG["compute_device"])
    # NOTE(review): DiffusionPipeline.to() support for memory_format kwargs
    # varies across diffusers versions — verify this is honored (the
    # transformer is already converted in acquire_transformer()).
    pipeline.to(memory_format=torch.channels_last)

    # Warm up twice with the same step count / guidance used at inference
    # time, so CUDA kernels and cuDNN benchmark choices are compiled for the
    # real path instead of the pipeline's (much larger) default step count.
    for _ in range(2):
        with torch.no_grad():
            pipeline(prompt=" ", guidance_scale=0.0, num_inference_steps=4)
    return pipeline


def load_pipeline() -> DiffusionPipeline:
    """Public interface to load the model pipeline.

    Returns:
        A configured diffusion pipeline ready for inference.
    """
    return initialize_pipeline()


@torch.no_grad()
def infer(
    request: TextToImageRequest,
    pipeline: DiffusionPipeline,
    generator: Generator,
) -> Image:
    """Generate an image from a text prompt.

    Args:
        request: The text-to-image generation request (prompt + dimensions).
        pipeline: The loaded diffusion pipeline.
        generator: Seeded random number generator for reproducible sampling.

    Returns:
        A PIL image generated from the prompt.
    """
    # schnell is a distilled model: 4 steps, no classifier-free guidance.
    result = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]


# Alias for backward compatibility.
load = load_pipeline