import os
import gc
import torch
from PIL.Image import Image
from torch import Generator
from typing import Optional, Dict, Any
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from diffusers import FluxTransformer2DModel
# Environment configuration
# Central settings shared by all loader helpers in this module.
MODEL_CONFIG = {
    "repository": "black-forest-labs/FLUX.1-schnell",  # base pipeline repo on the HF Hub
    "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",  # pinned commit for reproducibility
    "compute_device": "cuda",  # target device for the assembled pipeline
    "precision": torch.bfloat16,  # dtype used for every loaded component
    "memory_allocation": "expandable_segments:True"  # CUDA caching-allocator tuning
}
# Setup CUDA optimizations
torch.backends.cuda.matmul.allow_tf32 = True  # permit TF32 for float32 matmuls
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True  # autotune cuDNN algorithms (best for fixed input shapes)
# NOTE(review): the allocator config is applied via env var after torch is
# already imported — presumably read at first CUDA allocation; confirm it
# actually takes effect here.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = MODEL_CONFIG["memory_allocation"]
def reclaim_memory() -> None:
    """Release unused CPU and GPU memory resources.

    Runs the Python garbage collector, then — if CUDA is available —
    returns cached allocator blocks to the driver and resets the
    peak-memory statistics.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(); a single call covers both counters.
        torch.cuda.reset_peak_memory_stats()
def acquire_text_encoder() -> T5EncoderModel:
    """Download and return the T5 encoder used as the pipeline's text_encoder_2.

    Loads the pinned revision of the community checkpoint at the precision
    configured in MODEL_CONFIG.
    """
    return T5EncoderModel.from_pretrained(
        pretrained_model_name_or_path="manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=MODEL_CONFIG["precision"],
    )
def acquire_transformer() -> FluxTransformer2DModel:
    """Load the FLUX transformer from the local Hugging Face hub cache.

    Reads the pinned snapshot directly from HF_HUB_CACHE (no network hit)
    and converts the module to channels_last memory format.
    """
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    model = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=MODEL_CONFIG["precision"],
        use_safetensors=False,
    )
    return model.to(memory_format=torch.channels_last)
def initialize_pipeline(components: Optional[Dict[str, Any]] = None) -> DiffusionPipeline:
    """Construct, warm up, and return the diffusion pipeline.

    Args:
        components: Optional mapping of pre-built pipeline components keyed
            by diffusers component name (e.g. "text_encoder_2",
            "transformer"). Missing entries are loaded on demand. The
            caller's mapping is never mutated.

    Returns:
        A DiffusionPipeline moved to the configured device, converted to
        channels_last memory format, and warmed up with two dummy runs.
    """
    # Work on a copy: the previous version inserted loaded components into
    # the caller's dict in place (shared-mutable-argument bug).
    components = dict(components) if components else {}
    # Load the heavy components only when the caller did not supply them.
    if "text_encoder_2" not in components:
        components["text_encoder_2"] = acquire_text_encoder()
    if "transformer" not in components:
        components["transformer"] = acquire_transformer()
    # Create pipeline with components
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_CONFIG["repository"],
        revision=MODEL_CONFIG["revision"],
        torch_dtype=MODEL_CONFIG["precision"],
        **components
    )
    # Configure pipeline
    pipeline.to(MODEL_CONFIG["compute_device"])
    pipeline.to(memory_format=torch.channels_last)
    # Two throwaway generations so CUDA kernels are compiled/warmed before
    # the first real request.
    for _ in range(2):
        with torch.no_grad():
            pipeline(prompt=" ")
    return pipeline
def load_pipeline() -> DiffusionPipeline:
    """Public entry point: build and return the inference-ready pipeline.

    Returns:
        A configured diffusion pipeline ready for inference.
    """
    return initialize_pipeline()
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """Generate an image from a text prompt.

    Args:
        request: The text-to-image generation request (prompt, height, width).
        pipeline: The diffusion pipeline to run.
        generator: Seeded random-number generator for reproducible output.

    Returns:
        The first PIL image produced by the pipeline.
    """
    result = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,        # guidance disabled, as in the original config
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]
# Backward-compatible alias: older callers import `load` rather than
# `load_pipeline`. (Stray extraction artifact removed from this line.)
load = load_pipeline