import torch
import torch._dynamo
import gc
import os
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
# Environment configuration
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
# Constants
PIPELINE_MODEL_ID = "black-forest-labs/FLUX.1-schnell"
PIPELINE_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
TEXT_MODEL_ID = "Chucklee/extra1_ste1"
TEXT_MODEL_REVISION = "b0c1ffee1c1bdb3d30df17835615d809b7b8d075"
EXTRA_MODEL_ID = "Chucklee/extra2_ste2"
EXTRA_MODEL_REVISION = "3bfa327be3b38ee6f9c3ca7a5bfea6beeaa9306c"
TRANSFORMER_SNAPSHOT = "ed7260988c4cc0b3bcab5d1318997fd6fa99345b"
DEFAULT_PROMPT = "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper"
def load_pipeline() -> DiffusionPipeline:
    """Load the FLUX.1-schnell pipeline with a quantized VAE, a replacement
    T5 text encoder, and a custom transformer, then warm it up on CUDA.

    Returns:
        A CUDA-resident ``DiffusionPipeline`` ready for inference.
    """
    # VAE from the base pipeline repo, int8 weight-only quantized to cut memory.
    vae_model = AutoencoderKL.from_pretrained(
        PIPELINE_MODEL_ID,
        revision=PIPELINE_REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    quantize_(vae_model, int8_weight_only())
    # Replacement T5 encoder; wired in below as text_encoder_2.
    text_encoder = T5EncoderModel.from_pretrained(
        EXTRA_MODEL_ID,
        revision=EXTRA_MODEL_REVISION,
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    # Custom transformer loaded directly from a local HF hub cache snapshot.
    transformer_path = os.path.join(
        HF_HUB_CACHE, f"models--Chucklee--extra0_ste0/snapshots/{TRANSFORMER_SNAPSHOT}"
    )
    transformer_model = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)
    diffusion_pipeline = DiffusionPipeline.from_pretrained(
        PIPELINE_MODEL_ID,
        revision=PIPELINE_REVISION,
        # FIX: the quantized VAE was loaded above but never attached, so the
        # pipeline silently used the repo's unquantized VAE. Attach it here.
        vae=vae_model,
        transformer=transformer_model,
        text_encoder_2=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    diffusion_pipeline.to("cuda")
    # Two warm-up passes to trigger CUDA kernel compilation/caching up front.
    for _ in range(2):
        diffusion_pipeline(
            prompt=DEFAULT_PROMPT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return diffusion_pipeline
@torch.no_grad()
def generate_image(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
    """Run a single text-to-image inference pass and return the first image."""
    # Seed a generator on the pipeline's device for reproducible sampling.
    rng = Generator(pipeline.device).manual_seed(request.seed)
    # Fall back to the warm-up prompt when the request carries no prompt.
    effective_prompt = request.prompt or DEFAULT_PROMPT
    output = pipeline(
        prompt=effective_prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return output.images[0]