# LunarPhase / src/pipeline.py
# Chrissy1's picture
# Initial commit with folder contents
# f65f896 verified
import os
import gc
import torch
import numpy as np
from PIL import Image
from typing import Optional
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
FluxPipeline,
FluxTransformer2DModel
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
T5EncoderModel,
T5TokenizerFast,
CLIPTokenizer,
CLIPTextModel
)
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
# Pre-configurations
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
# Global variables
Pipeline = None
CKPT_ID = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
def empty_cache() -> None:
    """Run Python garbage collection and release cached CUDA memory.

    When CUDA is available, this also returns unused blocks held by the
    CUDA caching allocator and resets all CUDA peak-memory statistics. On
    hosts without CUDA, only ``gc.collect()`` runs.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # Querying or resetting CUDA state would raise without a device.
        return
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated: it only forwards to
    # reset_peak_memory_stats() and emits a FutureWarning on every call,
    # so calling both was redundant.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> FluxPipeline:
    """Load, assemble, move to CUDA, and warm up the diffusion pipeline.

    Components:
      * ``text_encoder_2`` (T5) and the transformer come from the pinned
        ``Chrissy1/extra0manQ0`` revision.
      * The VAE comes from ``CKPT_ID`` at ``CKPT_REVISION`` (local files
        only). It is int8 weight-only quantized with torchao.
      * All remaining components are loaded by ``FluxPipeline.from_pretrained``
        from ``CKPT_ID`` / ``CKPT_REVISION``.

    After assembly, the pipeline is moved to ``"cuda"``. It then runs once at
    1024x1024 with 4 steps as a warm-up, before being returned.

    Returns:
        The assembled ``FluxPipeline`` on the ``"cuda"`` device.
    """
    # Load text encoder.
    # NOTE(review): nn.Module.to(memory_format=...) only affects 4-D/5-D
    # parameters, so for mostly-Linear models this is likely a no-op.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "Chrissy1/extra0manQ0",
        revision="c0db1e82d89825a4664ad873f20d261cbe46e737",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Load transformer directly from a local HF cache snapshot directory.
    # NOTE(review): assumes the ``transformer`` folder of this repo/revision
    # is already present in HF_HUB_CACHE; the text-encoder load above only
    # requests the ``text_encoder_2`` subfolder. TODO confirm.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--Chrissy1--extra0manQ0/snapshots/c0db1e82d89825a4664ad873f20d261cbe46e737/transformer"
    )
    # use_safetensors=False loads non-safetensors (torch.load / pickle-based)
    # weights; this is only tolerable because the source revision is pinned.
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    # Load and quantize autoencoder.
    vae = AutoencoderKL.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    quantize_(vae, int8_weight_only())

    # Assemble the pipeline. The quantized VAE must be passed explicitly.
    # Otherwise from_pretrained loads a fresh, unquantized VAE and the one
    # built above is discarded.
    pipeline = FluxPipeline.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up run so one-time initialisation happens before the first request.
    with torch.inference_mode():
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: FluxPipeline, generator: Generator) -> Image.Image:
    """Generate one image for ``request`` using an already-loaded pipeline.

    Sampling uses fixed settings: 4 inference steps, ``guidance_scale=0.0``
    and ``max_sequence_length=256``. The output is requested as PIL.

    Args:
        request: Supplies ``prompt``, ``height`` and ``width``.
        pipeline: The pipeline to run, e.g. the result of ``load_pipeline()``.
        generator: Torch RNG forwarded to the pipeline's sampling.

    Returns:
        The first image of the pipeline result, as a ``PIL.Image.Image``.
    """
    # Release cached GPU memory before each run.
    empty_cache()
    result = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]