Spaces:
Paused
Paused
| from diffusers.pipelines import FluxPipeline | |
| from diffusers.utils import logging | |
| from diffusers.pipelines.flux.pipeline_flux import logger | |
| from torch import Tensor | |
| from typing import Optional, Dict, Any | |
| from .padding_orthogonalization import apply_padding_token_orthogonalization | |
def encode_images(pipeline: FluxPipeline, images: Tensor):
    """Encode a batch of images into packed Flux latent tokens plus their ids.

    Args:
        pipeline: FluxPipeline whose VAE / image processor do the encoding.
        images: Image batch accepted by ``pipeline.image_processor.preprocess``.

    Returns:
        Tuple of (images_tokens, images_ids) as produced by the pipeline's
        ``_pack_latents`` / ``_prepare_latent_image_ids`` helpers.
    """
    pixels = pipeline.image_processor.preprocess(images)
    pixels = pixels.to(pipeline.device).to(pipeline.dtype)

    latents = pipeline.vae.encode(pixels).latent_dist.sample()
    vae_cfg = pipeline.vae.config
    latents = (latents - vae_cfg.shift_factor) * vae_cfg.scaling_factor

    batch, channels, height, width = latents.shape
    tokens = pipeline._pack_latents(latents, batch, channels, height, width)
    ids = pipeline._prepare_latent_image_ids(
        batch,
        height,
        width,
        pipeline.device,
        pipeline.dtype,
    )
    # NOTE(review): newer diffusers versions expect half-resolution grid dims
    # for _prepare_latent_image_ids; retry with //2 when the token count and
    # id count disagree so both API variants are supported.
    if tokens.shape[1] != ids.shape[0]:
        ids = pipeline._prepare_latent_image_ids(
            batch,
            height // 2,
            width // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return tokens, ids
def prepare_text_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
    model_config: Optional[Dict[str, Any]] = None,
):
    """
    Prepare text input with optional padding token orthogonalization.

    Args:
        pipeline: FluxPipeline instance
        prompts: Text prompts to encode
        max_sequence_length: Maximum sequence length
        model_config: Optional configuration for orthogonalization

    Returns:
        Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids)
    """
    # Silence CLIP token-overflow warnings only for the duration of encoding.
    # Capture the current level so we can restore it afterwards instead of
    # unconditionally forcing WARNING (which would clobber a caller's setting).
    previous_level = logger.level
    logger.setLevel(logging.ERROR)
    try:
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(
            prompt=prompts,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=pipeline.device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            lora_scale=None,
        )
    finally:
        # Restore even if encode_prompt raises; fall back to WARNING when the
        # previous level was NOTSET (0), matching the original behavior.
        logger.setLevel(previous_level if previous_level else logging.WARNING)

    # Apply padding token orthogonalization if configured
    if model_config and model_config.get('padding_orthogonalization_enabled', False):
        prompt_embeds = apply_padding_token_orthogonalization(
            prompt_embeds=prompt_embeds,
            text_attention_mask=None,
            config=model_config,
        )

    return prompt_embeds, pooled_prompt_embeds, text_ids