# lanthanide_accurate / src / pipeline.py
# (Hugging Face page header preserved as comments: uploaded by "manbeast3b",
#  commit d000ac9 "Update src/pipeline.py", verified)
import os
import gc
import torch
from PIL.Image import Image
from torch import Generator
from typing import Optional, Dict, Any
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from diffusers import FluxTransformer2DModel
# Environment configuration
# Pinned base-model coordinates and runtime settings shared by all loaders.
MODEL_CONFIG = {
    "repository": "black-forest-labs/FLUX.1-schnell",   # base FLUX.1-schnell repo
    "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",  # pinned commit for reproducibility
    "compute_device": "cuda",
    "precision": torch.bfloat16,
    "memory_allocation": "expandable_segments:True",    # PYTORCH_CUDA_ALLOC_CONF value
}
# Setup CUDA optimizations
torch.backends.cuda.matmul.allow_tf32 = True   # allow TF32 matmuls for speed
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True          # autotune conv algorithms for fixed shapes
# NOTE(review): must be set before the first CUDA allocation to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = MODEL_CONFIG["memory_allocation"]
def reclaim_memory() -> None:
    """Release unused host and GPU memory and reset CUDA peak-usage stats.

    Safe on CPU-only hosts: all CUDA calls are skipped when no device is
    available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(); one call covers both counters.
        torch.cuda.reset_peak_memory_stats()
def acquire_text_encoder() -> T5EncoderModel:
    """Load the secondary T5 text encoder from its pinned checkpoint.

    Returns:
        A ``T5EncoderModel`` in the precision configured by ``MODEL_CONFIG``.
    """
    return T5EncoderModel.from_pretrained(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=MODEL_CONFIG["precision"],
    )
def acquire_transformer() -> FluxTransformer2DModel:
    """Load the FLUX transformer weights from the local HF hub cache.

    Returns:
        A ``FluxTransformer2DModel`` converted to channels-last memory layout.
    """
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    model = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=MODEL_CONFIG["precision"],
        use_safetensors=False,
    )
    # channels_last layout typically improves CUDA throughput — TODO confirm
    # it is honored by this transformer's ops.
    return model.to(memory_format=torch.channels_last)
def initialize_pipeline(components: Optional[Dict[str, Any]] = None) -> DiffusionPipeline:
    """Construct, configure, and warm up the diffusion pipeline.

    Args:
        components: Optional pre-loaded pipeline components keyed by name
            (e.g. ``"text_encoder_2"``, ``"transformer"``). Any missing
            component is loaded from its pinned checkpoint. The caller's
            dict is never mutated.

    Returns:
        A pipeline moved to the configured device, in channels-last memory
        format, and warmed up with two throwaway generations.
    """
    # Copy so filling in defaults does not mutate the caller's dict.
    components = dict(components) if components else {}
    if "text_encoder_2" not in components:
        components["text_encoder_2"] = acquire_text_encoder()
    if "transformer" not in components:
        components["transformer"] = acquire_transformer()

    # Create pipeline with the (possibly overridden) components.
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_CONFIG["repository"],
        revision=MODEL_CONFIG["revision"],
        torch_dtype=MODEL_CONFIG["precision"],
        **components,
    )

    # Configure placement and memory layout.
    pipeline.to(MODEL_CONFIG["compute_device"])
    pipeline.to(memory_format=torch.channels_last)

    # Warm-up passes so the first real request does not pay kernel
    # compilation / autotuning cost.
    for _ in range(2):
        with torch.no_grad():
            pipeline(prompt=" ")
    return pipeline
def load_pipeline() -> DiffusionPipeline:
    """Public entry point for loading the model pipeline.

    Returns:
        A fully configured diffusion pipeline, ready for inference.
    """
    return initialize_pipeline()
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """Generate an image from a text prompt.

    Args:
        request: The text-to-image generation request (prompt, width, height).
        pipeline: The prepared diffusion pipeline.
        generator: Seeded random-number generator for reproducibility.

    Returns:
        A PIL image generated from the prompt.
    """
    output = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,        # schnell variant runs without guidance
        num_inference_steps=4,     # few-step distilled sampling
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]
# Alias for backward compatibility
# Older callers import `load`; keep it pointing at load_pipeline.
load = load_pipeline