# Midnight45 / src/pipeline.py
# Author: BrenL
# Commit ce03777 (verified): "Initial commit with folder contents"
import os
import torch
import torch._dynamo
from PIL.Image import Image
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from diffusers import (
AutoencoderKL,
DiffusionPipeline,
FluxTransformer2DModel,
)
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only
# --- Process-wide runtime configuration -------------------------------------
# Allocator / tokenizer settings are exported through the environment; dynamo
# is told to suppress compilation errors rather than raise them.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
torch._dynamo.config.suppress_errors = True

# --- Model sources, each pinned to an exact revision ------------------------
# Base FLUX.1-schnell repository the pipeline (and its VAE) is built from.
IDS, REVISION = (
    "black-forest-labs/FLUX.1-schnell",
    "741f7c3ce8b383c54771c7003378a50191e9efe9",
)
# NOTE(review): not referenced by the code visible in this file — confirm use.
TT_IMAGE_MODEL, TT_IMAGE_REVISION = (
    "BrenL/extra1IMOO1",
    "3e33f01cda8a8c207218c2d31853fdc08bebd38f",
)
# Repository providing the replacement ``text_encoder_2`` (T5) weights.
EXTRA_TEXT_ENCODER, EXTRA_TEXT_REVISION = (
    "BrenL/extra2IMOO2",
    "f7538acf69d8b71458542b22257de6508850ab6d",
)
# Prompt used for warm-up runs and as a fallback when a request has none.
DEFAULT_PROMPT = "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper"
def load_pipeline() -> DiffusionPipeline:
    """
    Load, assemble and warm up the diffusion pipeline on CUDA.

    Components:
      * VAE from ``IDS`` @ ``REVISION`` (local files only), bf16, with torchao
        int8 weight-only quantization applied in place.
      * ``text_encoder_2`` (T5) from ``EXTRA_TEXT_ENCODER`` @
        ``EXTRA_TEXT_REVISION``, bf16.
      * Flux transformer from a pinned snapshot directory inside the local
        Hugging Face hub cache, bf16.

    The remaining components come from ``IDS`` @ ``REVISION``. The assembled
    pipeline is moved to ``"cuda"`` and run twice on ``DEFAULT_PROMPT`` before
    being returned.

    Returns:
        DiffusionPipeline: The ready-to-use pipeline.
    """
    vae = AutoencoderKL.from_pretrained(
        IDS,
        revision=REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    # quantize_ mutates the module in place; its return value is not needed.
    quantize_(vae, int8_weight_only())

    # nn.Module.to(memory_format=...) only affects 4-D parameters/buffers.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        EXTRA_TEXT_ENCODER,
        revision=EXTRA_TEXT_REVISION,
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # The transformer is read directly from a pinned snapshot directory in the
    # hub cache instead of being resolved by repo id.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--BrenL--extra0IMOO0/snapshots/422ee1f0f85ef1b035f00449540b254df85cd3a6",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # The quantized VAE must be passed explicitly. Without ``vae=vae``,
    # from_pretrained loads its own unquantized VAE and the one built above is
    # discarded.
    pipeline = DiffusionPipeline.from_pretrained(
        IDS,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up: two full 1024x1024 runs with the same sampler settings infer()
    # uses, presumably so one-time CUDA/lazy-initialisation costs are paid
    # before the first real request.
    for _ in range(2):
        pipeline(
            prompt=DEFAULT_PROMPT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """
    Generate a single image for ``request`` using ``pipeline``.

    Sampling uses guidance_scale=0.0, 4 inference steps and a maximum sequence
    length of 256.

    Args:
        request (TextToImageRequest): Carries ``prompt``, ``seed``, ``height``
            and ``width``. A missing or ``None`` prompt falls back to
            ``DEFAULT_PROMPT``. A ``None`` seed means no seeded generator is
            passed, so sampling is not reproducible.
        pipeline (DiffusionPipeline): The diffusion pipeline to use.

    Returns:
        Image: The first generated image.
    """
    # manual_seed(None) raises TypeError, so only build a seeded generator
    # when a seed was actually supplied.
    # NOTE(review): assumes TextToImageRequest.seed is optional — confirm.
    if request.seed is None:
        generator = None
    else:
        generator = torch.Generator(pipeline.device).manual_seed(request.seed)

    # Treat an absent attribute and an explicit None the same way; passing
    # None as the prompt would make the pipeline reject the call.
    prompt = getattr(request, "prompt", None)
    if prompt is None:
        prompt = DEFAULT_PROMPT

    return pipeline(
        prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]