# File size: 23,361 bytes (revision 56f2217)
import hashlib
import os
from typing import List, Optional, Union
import torch
from diffusers import FluxModularPipeline, ModularPipelineBlocks
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
from diffusers.utils import (
USE_PEFT_BACKEND,
logger,
scale_lora_layers,
unscale_lora_layers,
)
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
class CachedFluxTextEncoderStep(ModularPipelineBlocks):
    """Flux text-encoder block that memoizes CLIP and T5 prompt embeddings.

    Embeddings are cached per prompt string in an in-memory dict and mirrored
    to a single ``cache.safetensors`` file inside ``cache_dir``. Cache keys are
    the SHA-256 of the prompt text plus an encoder-specific suffix.

    NOTE(review): keys do not include the text-encoder checkpoint or the LoRA
    scale, so a cache directory should not be shared between setups whose text
    encoders differ -- confirm this matches the intended usage.
    """

    model_name = "flux"

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: Optional[str] = None,
        load_from_disk: bool = True,
    ) -> None:
        """Initialize the cached Flux text encoder step.

        Args:
            use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
            cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
            load_from_disk: Whether to load existing cache from disk on initialization.
                Ignored when ``use_cache`` is False. Defaults to True.
        """
        super().__init__()
        # ``None`` (rather than an empty dict) marks caching as disabled.
        self.cache = {} if use_cache else None
        if use_cache:
            self.cache_dir = cache_dir or os.path.join(
                os.path.expanduser("~"), ".cache", "flux_prompt_cache"
            )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        # Load existing cache if requested
        if load_from_disk and use_cache:
            self.load_cache_from_disk()

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text_embeddings to guide the image generation"

    @property
    def expected_components(self):
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Validate that ``prompt`` and ``prompt_2`` are each None, a str or a list.

        Raises:
            ValueError: If either prompt has any other type.
        """
        for prompt in (block_state.prompt, block_state.prompt_2):
            if prompt is not None and not isinstance(prompt, (str, list)):
                raise ValueError(
                    f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
                )

    def _cache_file_path(self) -> Optional[str]:
        """Return the path of the on-disk cache file, or None if there is no cache dir."""
        if not self.cache_dir:
            return None
        return os.path.join(self.cache_dir, "cache.safetensors")

    def save_cache_to_disk(self):
        """Save the current cache to disk as a safetensors file.

        Does nothing when the cache is disabled or empty. The file is written
        to a temporary sibling first and then moved into place, so an
        interrupted write cannot leave a truncated cache file behind.
        """
        cache_file = self._cache_file_path()
        if not self.cache or cache_file is None:
            return

        # safetensors needs contiguous tensors; they are also moved to CPU
        # before being written.
        tensors_to_save = {
            key: tensor.detach().cpu().contiguous()
            for key, tensor in self.cache.items()
        }

        tmp_file = cache_file + ".tmp"
        save_file(tensors_to_save, tmp_file)
        os.replace(tmp_file, cache_file)
        logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")

    def load_cache_from_disk(self):
        """Load cache entries from the safetensors file into the in-memory cache.

        Best effort: a missing file is ignored and any read error is logged as
        a warning instead of being raised. Loaded tensors live on CPU.
        """
        cache_file = self._cache_file_path()
        if cache_file is None or self.cache is None:
            return
        if not os.path.exists(cache_file):
            return

        try:
            with safe_open(cache_file, framework="pt", device="cpu") as f:
                loaded_count = 0
                for key in f.keys():
                    self.cache[key] = f.get_tensor(key)
                    loaded_count += 1
            logger.debug(
                f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
            )
        except Exception as e:
            # Deliberately broad: a corrupt cache must not prevent encoding.
            logger.warning(f"Failed to load cache from disk: {e}")

    def clear_cache_from_disk(self):
        """Delete the cache file from disk and empty the in-memory cache."""
        cache_file = self._cache_file_path()
        if cache_file is None:
            return

        if os.path.exists(cache_file):
            os.remove(cache_file)
            logger.info(f"Cleared cache file: {cache_file}")

        # Also clear the in-memory cache
        if self.cache:
            self.cache.clear()

    def get_cache_size(self):
        """Return the size of the on-disk cache file in MB (0 if there is none)."""
        cache_file = self._cache_file_path()
        if cache_file is not None and os.path.exists(cache_file):
            return os.path.getsize(cache_file) / (1024 * 1024)  # Convert to MB
        return 0

    @staticmethod
    def _to_cache_key(prompt: str) -> str:
        """Generate a hash key for a single prompt string."""
        return hashlib.sha256(prompt.encode()).hexdigest()

    @staticmethod
    def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
        """Split prompts into cached and new, returning indices for reconstruction.

        Args:
            prompts: List of prompt strings to check against cache.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
            device: Optional device to move cached tensors to.

        Returns:
            tuple: (cached_embeds, prompts_to_encode, prompt_indices)
                - cached_embeds: List of (idx, embedding) tuples for cached prompts
                - prompts_to_encode: List of prompts that need encoding
                - prompt_indices: List of original indices for prompts_to_encode
        """
        # Compare against None explicitly so the lookup does not depend on how
        # the instance evaluates in a boolean context.
        cache = cache_instance.cache if cache_instance is not None else None

        cached_embeds = []
        prompts_to_encode = []
        prompt_indices = []
        for idx, prompt in enumerate(prompts):
            cached_tensor = None
            if cache:
                cache_key = CachedFluxTextEncoderStep._to_cache_key(
                    prompt + cache_suffix
                )
                cached_tensor = cache.get(cache_key)

            if cached_tensor is None:
                prompts_to_encode.append(prompt)
                prompt_indices.append(idx)
                continue

            # Move tensor to the correct device if specified
            if device is not None and cached_tensor.device != device:
                cached_tensor = cached_tensor.to(device)
            cached_embeds.append((idx, cached_tensor))

        return cached_embeds, prompts_to_encode, prompt_indices

    @staticmethod
    def _cache_prompt_embeds(
        prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
    ):
        """Store newly computed embeddings in cache and save to disk.

        Args:
            prompts: Original full list of prompts.
            prompt_indices: Indices of newly encoded prompts in the original list.
            prompt_embeds: Newly computed embeddings tensor; row ``i`` belongs to
                ``prompts[prompt_indices[i]]``.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
        """
        if cache_instance is None or cache_instance.cache is None:
            return
        if not prompt_indices:
            return

        for i, idx in enumerate(prompt_indices):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(
                prompts[idx] + cache_suffix
            )
            # Store an independent CPU copy. A plain slice would be a view that
            # keeps the whole (possibly GPU-resident) batch tensor alive for as
            # long as the cache entry exists.
            cache_instance.cache[cache_key] = (
                prompt_embeds[i : i + 1].detach().to(device="cpu", copy=True)
            )

        # Save updated cache to disk
        cache_instance.save_cache_to_disk()

    @staticmethod
    def _merge_cached_prompt_embeds(
        cached_embeds, prompt_indices, prompt_embeds, batch_size
    ):
        """Merge cached and newly computed embeddings back into original batch order.

        Args:
            cached_embeds: List of (idx, embedding) tuples from cache.
            prompt_indices: Indices where new embeddings should be placed.
            prompt_embeds: Newly computed embeddings tensor, or None if all cached.
            batch_size: Total batch size for output tensor.

        Returns:
            torch.Tensor: Combined embeddings tensor in correct batch order.
        """
        all_embeds = [None] * batch_size

        # Place cached embeddings. When fresh embeddings exist, align the cached
        # ones to their device/dtype so the concatenation below cannot fail on a
        # device mismatch or silently promote the dtype.
        for idx, embed in cached_embeds:
            if prompt_embeds is not None:
                embed = embed.to(
                    device=prompt_embeds.device, dtype=prompt_embeds.dtype
                )
            all_embeds[idx] = embed

        # Place new embeddings
        if prompt_embeds is not None:
            for i, idx in enumerate(prompt_indices):
                all_embeds[idx] = prompt_embeds[i : i + 1]

        return torch.cat(all_embeds, dim=0)

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using T5 text encoder with caching support.

        Args:
            components: Pipeline components containing T5 encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            max_sequence_length: Maximum sequence length for tokenization.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: T5 prompt embeddings of shape
            ``(batch_size * num_images_per_prompt, seq_len, dim)``, cast to the
            dtype of ``components.text_encoder_2``.
        """
        dtype = components.text_encoder_2.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Prompts are padded to ``max_sequence_length``, so embeddings produced
        # with different lengths have different shapes and must not share a key.
        # The default length keeps the plain "_t5" suffix so caches written
        # before this distinction existed stay valid.
        cache_suffix = (
            "_t5" if max_sequence_length == 512 else f"_t5_{max_sequence_length}"
        )

        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, cache_suffix, device
            )
        )

        new_embeds = None
        if prompts_to_encode:
            if isinstance(components, TextualInversionLoaderMixin):
                prompts_to_encode = components.maybe_convert_prompt(
                    prompts_to_encode, components.tokenizer_2
                )

            text_inputs = components.tokenizer_2(
                prompts_to_encode,
                padding="max_length",
                max_length=max_sequence_length,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            # Check for truncation
            untruncated_ids = components.tokenizer_2(
                prompts_to_encode, padding="longest", return_tensors="pt"
            ).input_ids
            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = components.tokenizer_2.batch_decode(
                    untruncated_ids[:, max_sequence_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f" {max_sequence_length} tokens: {removed_text}"
                )

            new_embeds = components.text_encoder_2(
                text_input_ids.to(device), output_hidden_states=False
            )[0]

            # Keys use the original prompt text (``prompt``), not the
            # textual-inversion-converted strings.
            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, new_embeds, cache_instance, cache_suffix
            )

        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, new_embeds, batch_size
        )

        _, seq_len, _ = prompt_embeds.shape
        # Duplicate for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )
        # Applied on the all-cached path too, so cache hits and misses return
        # tensors with the same dtype and device.
        return prompt_embeds.to(dtype=dtype, device=device)

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using CLIP text encoder with caching support.

        Args:
            components: Pipeline components containing CLIP encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: CLIP pooled prompt embeddings of shape
            ``(batch_size * num_images_per_prompt, dim)``.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Split cached and new prompts
        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_clip", device
            )
        )

        # The encoder is only touched when at least one prompt missed the cache.
        new_embeds = None
        if prompts_to_encode:
            if isinstance(components, TextualInversionLoaderMixin):
                prompts_to_encode = components.maybe_convert_prompt(
                    prompts_to_encode, components.tokenizer
                )

            tokenizer_max_length = components.tokenizer.model_max_length
            text_inputs = components.tokenizer(
                prompts_to_encode,
                padding="max_length",
                max_length=tokenizer_max_length,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            untruncated_ids = components.tokenizer(
                prompts_to_encode, padding="longest", return_tensors="pt"
            ).input_ids
            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = components.tokenizer.batch_decode(
                    untruncated_ids[:, tokenizer_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer_max_length} tokens: {removed_text}"
                )

            encoder_output = components.text_encoder(
                text_input_ids.to(device), output_hidden_states=False
            )
            # Use pooled output of CLIPTextModel
            new_embeds = encoder_output.pooler_output.to(
                dtype=components.text_encoder.dtype, device=device
            )

            # Cache the new embeddings
            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, new_embeds, cache_instance, "_clip"
            )

        # Combine cached and newly encoded embeddings in correct order
        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, new_embeds, batch_size
        )

        # Duplicate for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
        cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            max_sequence_length (`int`):
                Maximum number of tokens used when tokenizing for `text_encoder_2`.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            cache_instance (`CachedFluxTextEncoderStep`, *optional*):
                Instance whose cache is consulted and updated. If `None`, nothing is cached.

        Returns:
            `tuple`: `(prompt_embeds, pooled_prompt_embeds, text_ids)`, where `text_ids` is a zero tensor of shape
            `(prompt_embeds.shape[1], 3)`.
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        # NOTE(review): cache keys do not include `lora_scale`; if the text
        # encoders carry LoRA layers, cached embeddings may have been computed
        # under a different scale -- confirm whether that matters for callers.
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                cache_instance=cache_instance,
            )
            prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                cache_instance=cache_instance,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = (
            components.text_encoder.dtype
            if components.text_encoder is not None
            else torch.bfloat16
        )
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode the prompts held in ``state`` and write the embeddings back.

        Reads ``prompt``, ``prompt_2`` and ``joint_attention_kwargs`` from the
        block state and stores ``prompt_embeds``, ``pooled_prompt_embeds`` and
        ``text_ids`` on it.

        Returns:
            The ``(components, state)`` pair.
        """
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.text_ids,
        ) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            # Forward the declared ``prompt_2`` input; ``encode_prompt`` falls
            # back to ``prompt`` when it is None.
            prompt_2=block_state.prompt_2,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,  # TODO: hardcoded for now.
            max_sequence_length=512,
            lora_scale=block_state.text_encoder_lora_scale,
            # Pass self as cache_instance only when caching is enabled
            cache_instance=self if self.cache is not None else None,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
|