| import os |
| from typing import Any, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection |
| from .strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy |
|
|
|
|
| from .utils import setup_logging |
|
|
| setup_logging() |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
# Identifiers handed to `_load_tokenizer` for the two SDXL CLIP tokenizers
# (presumably Hugging Face Hub model ids -- confirm against `_load_tokenizer`).
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|
|
|
class SdxlTokenizeStrategy(TokenizeStrategy):
    """Tokenize captions with the two CLIP tokenizers used for SDXL.

    Both tokenizers are loaded through the base-class ``_load_tokenizer`` helper
    and every caption is tokenized twice, once per tokenizer.
    """

    def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """Load both tokenizers and resolve the effective token length.

        Args:
            max_length: Requested token length. ``None`` falls back to
                ``tokenizer1.model_max_length``; otherwise two is added to the
                given value (presumably for the BOS/EOS tokens -- confirm).
            tokenizer_cache_dir: Optional cache directory forwarded to
                ``_load_tokenizer``.
        """
        self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        # The second tokenizer's pad token id is overridden with 0.
        self.tokenizer2.pad_token_id = 0

        if max_length is None:
            self.max_length = self.tokenizer1.model_max_length
        else:
            self.max_length = max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one caption or a list of captions with both tokenizers.

        Args:
            text: A single caption or a list of captions.

        Returns:
            A pair ``(input_ids1, input_ids2)``; each tensor stacks the
            per-caption ids along a new leading batch dimension.
        """
        captions = [text] if isinstance(text, str) else text
        input_ids1 = torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in captions], dim=0)
        input_ids2 = torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in captions], dim=0)
        return (input_ids1, input_ids2)

    # NOTE: ``Union`` is used instead of ``str | List[str]`` so that the module
    # also imports on Python < 3.10 and matches the annotations used elsewhere.
    def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Tokenize captions and also return the per-token prompt weights.

        Args:
            text: A single caption or a list of captions.

        Returns:
            ``(tokens, weights)`` where ``tokens`` is ``[input_ids1, input_ids2]``
            and ``weights`` is ``[weights1, weights2]``; every tensor is stacked
            along a new leading batch dimension.
        """
        captions = [text] if isinstance(text, str) else text
        tokens1_list, tokens2_list = [], []
        weights1_list, weights2_list = [], []
        for caption in captions:
            tokens1, weights1 = self._get_input_ids(self.tokenizer1, caption, self.max_length, weighted=True)
            tokens2, weights2 = self._get_input_ids(self.tokenizer2, caption, self.max_length, weighted=True)
            tokens1_list.append(tokens1)
            tokens2_list.append(tokens2)
            weights1_list.append(weights1)
            weights2_list.append(weights2)

        tokens = [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)]
        weights = [torch.stack(weights1_list, dim=0), torch.stack(weights2_list, dim=0)]
        return tokens, weights
|
|
|
|
class SdxlTextEncodingStrategy(TextEncodingStrategy):
    """Run SDXL token ids through the two CLIP text encoders."""

    def __init__(self) -> None:
        pass

    def _pool_workaround(
        self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
    ):
        """
        workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
        instead of the hidden states for the EOS token
        If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output

        Original code from CLIP's pooling function:

            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]

        Args:
            text_encoder: Encoder providing the ``text_projection`` layer.
            last_hidden_state: Final hidden states of ``text_encoder``.
            input_ids: Token ids that produced ``last_hidden_state``.
            eos_token_id: Id of the EOS token to pool at.

        Returns:
            The projected hidden state at the EOS position of each sequence,
            cast back to the dtype of ``last_hidden_state``.
        """
        # 1 where the token is EOS, 0 elsewhere.
        eos_token_mask = (input_ids == eos_token_id).int()

        # argmax yields the first position holding the maximum, i.e. the first
        # EOS token of each row (index 0 if a row contains no EOS at all).
        eos_token_index = torch.argmax(eos_token_mask, dim=1)
        eos_token_index = eos_token_index.to(device=last_hidden_state.device)

        # Pick the hidden state at the EOS position for every row.
        batch_index = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        pooled_output = last_hidden_state[batch_index, eos_token_index]

        # Apply the projection in the projection layer's dtype, then restore
        # the dtype of the hidden states.
        projection = text_encoder.text_projection
        pooled_output = projection(pooled_output.to(projection.weight.dtype))
        pooled_output = pooled_output.to(last_hidden_state.dtype)

        return pooled_output

    @staticmethod
    def _concat_chunk_states(hidden_states: torch.Tensor, max_token_length: int, chunk_length: int) -> torch.Tensor:
        """Stitch per-chunk hidden states back into one sequence.

        Keeps the very first position, then for every chunk the
        ``chunk_length - 2`` positions after the chunk's first one, and finally
        the very last position.
        """
        pieces = [hidden_states[:, 0].unsqueeze(1)]
        for start in range(1, max_token_length, chunk_length):
            pieces.append(hidden_states[:, start : start + chunk_length - 2])
        pieces.append(hidden_states[:, -1].unsqueeze(1))
        return torch.cat(pieces, dim=1)

    def _get_hidden_states_sdxl(
        self,
        input_ids1: torch.Tensor,
        input_ids2: torch.Tensor,
        tokenizer1: CLIPTokenizer,
        tokenizer2: CLIPTokenizer,
        text_encoder1: Union[CLIPTextModel, torch.nn.Module],
        text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
        unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
    ):
        """Encode both id tensors and return the states used for conditioning.

        Args:
            input_ids1: Token ids for ``text_encoder1``; indexed here as
                (batch, chunks, tokens).
            input_ids2: Token ids for ``text_encoder2``, same layout.
            tokenizer1: Tokenizer matching ``text_encoder1``.
            tokenizer2: Tokenizer matching ``text_encoder2``.
            text_encoder1: First text encoder.
            text_encoder2: Second text encoder (possibly wrapped).
            unwrapped_text_encoder2: Unwrapped second encoder used to reach
                ``text_projection``; defaults to ``text_encoder2``.

        Returns:
            ``(hidden_states1, hidden_states2, pool2)``.
        """
        b_size = input_ids1.size()[0]
        if input_ids1.size()[1] == 1:
            max_token_length = None
        else:
            max_token_length = input_ids1.size()[1] * input_ids1.size()[2]

        # Flatten (batch, chunks, tokens) into (batch * chunks, tokens) and move
        # each tensor to the device of the encoder that consumes it.
        input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))
        input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))
        input_ids1 = input_ids1.to(text_encoder1.device)
        input_ids2 = input_ids2.to(text_encoder2.device)

        # Encoder 1: take the hidden state at index 11.
        enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
        hidden_states1 = enc_out["hidden_states"][11]

        # Encoder 2: take the second-to-last hidden state.
        enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
        hidden_states2 = enc_out["hidden_states"][-2]

        # Pooled output comes from the unwrapped encoder. An explicit None check
        # is used rather than `or`, so the choice never depends on the
        # truthiness of a module object.
        if unwrapped_text_encoder2 is None:
            unwrapped_text_encoder2 = text_encoder2
        pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

        # Restore the batch dimension: (batch * chunks, tokens, dim) -> (batch, chunks * tokens, dim).
        # NOTE(review): 75 is presumably model_max_length - 2 -- confirm.
        n_size = 1 if max_token_length is None else max_token_length // 75
        hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
        hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

        if max_token_length is not None:
            hidden_states1 = self._concat_chunk_states(hidden_states1, max_token_length, tokenizer1.model_max_length)
            hidden_states2 = self._concat_chunk_states(hidden_states2, max_token_length, tokenizer2.model_max_length)

            # Keep one pooled vector per batch item (the one from its first chunk).
            pool2 = pool2[::n_size]

        return hidden_states1, hidden_states2, pool2

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Args:
            tokenize_strategy: TokenizeStrategy
            models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
                If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
            tokens: List of tokens, for text_encoder1 and text_encoder2

        Returns:
            [hidden_states1, hidden_states2, pool2]
        """
        if len(models) == 2:
            text_encoder1, text_encoder2 = models
            unwrapped_text_encoder2 = None
        else:
            text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
        tokens1, tokens2 = tokens
        tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2

        hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
            tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
        )
        return [hidden_states1, hidden_states2, pool2]

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens_list: List[torch.Tensor],
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """Encode tokens and scale the hidden states by per-token weights.

        Args:
            tokenize_strategy: Strategy providing ``tokenizer1`` / ``tokenizer2``.
            models: Same list as accepted by ``encode_tokens``.
            tokens_list: Token ids for encoder 1 and encoder 2.
            weights_list: Weights for encoder 1 and encoder 2, indexed here as
                (batch, chunks, tokens).

        Returns:
            [hidden_states1, hidden_states2, pool2]; ``pool2`` is not weighted.
        """
        hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)

        # Move each weight tensor to the device of the hidden states it scales.
        # The two encoders may live on different devices, so encoder 2's weights
        # must not be forced onto encoder 1's device.
        weights_list = [
            weights.to(hidden_states.device) for weights, hidden_states in zip(weights_list, (hidden_states1, hidden_states2))
        ]

        if weights_list[0].shape[1] == 1:
            # Single chunk: (batch, 1, tokens) -> (batch, tokens, 1), broadcast over the feature dim.
            hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
            hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
        else:
            # Multiple chunks: scale each 75-position span in place, using the
            # chunk's weights without their first and last entries.
            for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
                for i in range(weight.shape[1]):
                    start, end = i * 75 + 1, i * 75 + 76
                    hidden_states[:, start:end] = hidden_states[:, start:end] * weight[:, i, 1:-1].unsqueeze(-1)

        return [hidden_states1, hidden_states2, pool2]
|
|
|
|
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache SDXL text-encoder outputs either in ``.npz`` files or in memory."""

    # Suffix appended to the image path (minus its extension) to form the cache file name.
    SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        is_weighted: bool = False,
    ) -> None:
        """Forward all settings unchanged to the base caching strategy."""
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache path for an image: its path without extension plus the suffix."""
        return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Tell whether ``npz_path`` holds a usable cache entry.

        Returns ``False`` when disk caching is off, the file is missing, or any
        of the keys ``hidden_state1`` / ``hidden_state2`` / ``pool2`` is absent.
        Returns ``True`` without opening the file when the validity check is
        skipped.

        Raises:
            Exception: Whatever ``np.load`` raises for an unreadable file; the
                error is logged and re-raised.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            # The context manager closes the archive's file handle on every path.
            with np.load(npz_path) as npz:
                if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
                    return False
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load ``[hidden_state1, hidden_state2, pool2]`` from a cache file."""
        # Arrays are read when indexed, so they remain valid after the archive is closed.
        with np.load(npz_path) as data:
            hidden_state1 = data["hidden_state1"]
            hidden_state2 = data["hidden_state2"]
            pool2 = data["pool2"]
        return [hidden_state1, hidden_state2, pool2]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """Encode the captions of ``infos`` and store the outputs per item.

        Args:
            tokenize_strategy: Strategy used to tokenize the captions.
            models: Models forwarded to the encoding strategy.
            text_encoding_strategy: Strategy that performs the encoding.
            infos: Items exposing ``caption``; results are written either to the
                file named by ``text_encoder_outputs_npz`` or to the
                ``text_encoder_outputs`` attribute.
        """
        captions = [info.caption for info in infos]

        if self.is_weighted:
            tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
            with torch.no_grad():
                hidden_state1, hidden_state2, pool2 = text_encoding_strategy.encode_tokens_with_weights(
                    tokenize_strategy, models, tokens_list, weights_list
                )
        else:
            tokens1, tokens2 = tokenize_strategy.tokenize(captions)
            with torch.no_grad():
                hidden_state1, hidden_state2, pool2 = text_encoding_strategy.encode_tokens(
                    tokenize_strategy, models, [tokens1, tokens2]
                )

        # NumPy has no bfloat16 dtype, so such tensors are widened to float32 first.
        outputs = [hidden_state1, hidden_state2, pool2]
        outputs = [t.float() if t.dtype == torch.bfloat16 else t for t in outputs]
        hidden_state1, hidden_state2, pool2 = [t.cpu().numpy() for t in outputs]

        for i, info in enumerate(infos):
            hidden_state1_i = hidden_state1[i]
            hidden_state2_i = hidden_state2[i]
            pool2_i = pool2[i]

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state1=hidden_state1_i,
                    hidden_state2=hidden_state2_i,
                    pool2=pool2_i,
                )
            else:
                info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
|