# Ad_gen/cascade/pipeline_tools.py
# (Source: Flulike99's "cascade" repo, commit 6df1cf2)
from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.pipelines.flux.pipeline_flux import logger
from torch import Tensor
from typing import Optional, Dict, Any
from .padding_orthogonalization import apply_padding_token_orthogonalization
def encode_images(pipeline: FluxPipeline, images: Tensor):
    """Encode pixel images into packed Flux latent tokens plus position ids.

    Runs the pipeline's image preprocessor, samples latents from the VAE
    posterior, applies the VAE shift/scale normalization, packs the latents
    into the token layout the Flux transformer expects, and builds the
    matching latent-image position ids.

    Args:
        pipeline: FluxPipeline providing the image processor, VAE, and the
            private packing/id helpers.
        images: Batch of input images accepted by the pipeline's
            ``image_processor.preprocess``.

    Returns:
        Tuple ``(images_tokens, images_ids)`` of packed latent tokens and
        their position ids.
    """
    pixels = pipeline.image_processor.preprocess(images)
    pixels = pixels.to(pipeline.device).to(pipeline.dtype)

    latents = pipeline.vae.encode(pixels).latent_dist.sample()
    latents = (
        latents - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor

    batch, _, latent_h, latent_w = latents.shape
    images_tokens = pipeline._pack_latents(latents, *latents.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        batch,
        latent_h,
        latent_w,
        pipeline.device,
        pipeline.dtype,
    )
    # Some diffusers versions expect the packed (halved) resolution in
    # _prepare_latent_image_ids; retry with halved dims when the id count
    # does not line up with the token count.
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = pipeline._prepare_latent_image_ids(
            batch,
            latent_h // 2,
            latent_w // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids
def prepare_text_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
    model_config: Optional[Dict[str, Any]] = None
):
    """
    Prepare text input with optional padding token orthogonalization.

    Temporarily raises the Flux pipeline logger to ERROR to silence the
    CLIP token-overflow warning, and restores the logger's previous level
    even if encoding fails.

    Args:
        pipeline: FluxPipeline instance
        prompts: Text prompts to encode
        max_sequence_length: Maximum sequence length
        model_config: Optional configuration for orthogonalization
    Returns:
        Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids)
    """
    # Save the current level so we restore what the caller configured,
    # rather than hard-coding WARNING; use try/finally so an exception in
    # encode_prompt cannot leave the logger stuck at ERROR.
    previous_level = logger.level
    logger.setLevel(logging.ERROR)
    try:
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(
            prompt=prompts,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=pipeline.device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            lora_scale=None,
        )
        # Apply padding token orthogonalization if configured
        if model_config and model_config.get('padding_orthogonalization_enabled', False):
            prompt_embeds = apply_padding_token_orthogonalization(
                prompt_embeds=prompt_embeds,
                text_attention_mask=None,
                config=model_config,
            )
    finally:
        logger.setLevel(previous_level)
    return prompt_embeds, pooled_prompt_embeds, text_ids