| import os |
| import gc |
| import torch |
| from PIL.Image import Image |
| from torch import Generator |
| from typing import Optional, Dict, Any |
| from diffusers import DiffusionPipeline |
| from transformers import T5EncoderModel |
| from huggingface_hub.constants import HF_HUB_CACHE |
| from pipelines.models import TextToImageRequest |
| from diffusers import FluxTransformer2DModel |
|
|
| |
# Central configuration for the pipeline: which repo/revision to load,
# where to run it, the numeric precision, and the CUDA allocator policy.
MODEL_CONFIG = {
    # Base FLUX.1-schnell repository the pipeline is assembled from.
    "repository": "black-forest-labs/FLUX.1-schnell",
    # Pinned commit for reproducible downloads.
    "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",
    "compute_device": "cuda",
    # bfloat16 halves memory vs fp32 while keeping fp32-like dynamic range.
    "precision": torch.bfloat16,
    # Value written into PYTORCH_CUDA_ALLOC_CONF below; lets the caching
    # allocator grow segments instead of fragmenting.
    "memory_allocation": "expandable_segments:True"
}
|
|
| |
# Global PyTorch performance knobs: TF32 matmuls and cuDnn autotuning trade
# a small amount of precision/startup time for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read lazily by the CUDA caching
# allocator, so setting it after `import torch` appears safe as long as no
# CUDA allocation happened yet — confirm nothing allocates before this line.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = MODEL_CONFIG["memory_allocation"]
|
|
def reclaim_memory() -> None:
    """Release unused GPU memory resources.

    Runs the Python garbage collector, then — when CUDA is available —
    returns cached allocator blocks to the driver and resets the peak
    memory statistics used for profiling. Safe to call on CPU-only hosts.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias that simply
        # calls reset_peak_memory_stats(), so one call covers both.
        torch.cuda.reset_peak_memory_stats()
|
|
def acquire_text_encoder() -> T5EncoderModel:
    """Download and instantiate the secondary T5 text encoder.

    Pulls the ``text_encoder_2`` subfolder from a pinned revision of the
    custom checkpoint repo, loaded at the configured precision.
    """
    return T5EncoderModel.from_pretrained(
        pretrained_model_name_or_path="manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=MODEL_CONFIG["precision"],
    )
|
|
def acquire_transformer() -> FluxTransformer2DModel:
    """Instantiate the FLUX transformer from the local HF hub cache.

    Loads a pinned snapshot directly from the cache directory (bypassing
    a hub round-trip) and converts it to channels_last memory layout.
    """
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    model = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=MODEL_CONFIG["precision"],
        use_safetensors=False,
    )
    # channels_last layout generally favors conv/attention kernels on CUDA.
    return model.to(memory_format=torch.channels_last)
|
|
def initialize_pipeline(components: Optional[Dict[str, Any]] = None) -> DiffusionPipeline:
    """Construct and initialize the diffusion pipeline.

    Args:
        components: Optional pre-built pipeline components keyed by the
            ``DiffusionPipeline.from_pretrained`` argument name (e.g.
            ``"transformer"``, ``"text_encoder_2"``). Any missing entry
            is loaded here.

    Returns:
        A warmed-up ``DiffusionPipeline`` placed on the configured device.
    """
    # Work on a shallow copy so the caller's dict is never mutated
    # (the original implementation inserted loaded components into it).
    components = {} if components is None else dict(components)

    if "text_encoder_2" not in components:
        components["text_encoder_2"] = acquire_text_encoder()

    if "transformer" not in components:
        components["transformer"] = acquire_transformer()

    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_CONFIG["repository"],
        revision=MODEL_CONFIG["revision"],
        torch_dtype=MODEL_CONFIG["precision"],
        **components,
    )

    pipeline.to(MODEL_CONFIG["compute_device"])
    pipeline.to(memory_format=torch.channels_last)

    # Warm-up passes trigger kernel selection/compilation so the first
    # real request is not slowed down; the outputs are discarded.
    with torch.no_grad():
        for _ in range(2):
            pipeline(prompt=" ")

    return pipeline
|
|
def load_pipeline() -> DiffusionPipeline:
    """Public interface to load the model pipeline.

    Returns:
        A configured diffusion pipeline ready for inference.
    """
    pipeline = initialize_pipeline()
    return pipeline
|
|
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """Generate an image from a text prompt.

    Args:
        request: The text-to-image generation request (prompt plus
            target height/width).
        pipeline: The diffusion pipeline to run.
        generator: Seeded random number generator for reproducibility.

    Returns:
        A PIL image generated from the prompt.
    """
    output = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]
|
|
| |
| load = load_pipeline |