| | import glob |
| | import os |
| | from typing import Any, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | from transformers import CLIPTokenizer |
| | from . import train_util |
| | from .strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy |
| | from .utils import setup_logging |
| |
|
| | setup_logging() |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
# Model identifier handed to the tokenizer loader when the v1 tokenizer is requested.
TOKENIZER_ID = "openai/clip-vit-large-patch14"
# Model identifier used when the v2 tokenizer is requested; the loader is pointed
# at its "tokenizer" subfolder.
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2"
| |
|
| |
|
class SdTokenizeStrategy(TokenizeStrategy):
    """Tokenize prompts with the CLIP tokenizer selected for Stable Diffusion v1 or v2.

    The tokenizer is loaded once in ``__init__`` and reused by ``tokenize`` and
    ``tokenize_with_weights``. Both methods accept either a single prompt or a
    list of prompts and always return batched tensors wrapped in a list.
    """

    def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length does not include <BOS> and <EOS> (None, 75, 150, 225)

        Args:
            v2: load the tokenizer from ``V2_STABLE_DIFFUSION_ID`` (subfolder
                ``"tokenizer"``) instead of ``TOKENIZER_ID``.
            max_length: prompt length without <BOS>/<EOS>; ``None`` falls back to
                the tokenizer's own ``model_max_length``.
            tokenizer_cache_dir: forwarded unchanged to ``self._load_tokenizer``.
        """
        logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
        if v2:
            self.tokenizer = self._load_tokenizer(
                CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
            )
        else:
            self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

        if max_length is None:
            self.max_length = self.tokenizer.model_max_length
        else:
            # Stored length includes the two special tokens the caller's value leaves out.
            self.max_length = max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Return a one-element list holding the stacked input ids of every prompt.

        A bare string is treated as a batch of one prompt.
        """
        text = [text] if isinstance(text, str) else text
        return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

    def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return ``([input_ids], [weights])``, each tensor stacked over the prompts.

        ``Union`` is used instead of ``str | List[str]`` so the annotation also
        evaluates on interpreters older than Python 3.10, matching ``tokenize``.
        """
        text = [text] if isinstance(text, str) else text
        tokens_list = []
        weights_list = []
        for prompt in text:
            tokens, weights = self._get_input_ids(self.tokenizer, prompt, self.max_length, weighted=True)
            tokens_list.append(tokens)
            weights_list.append(weights)
        return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
| |
|
| |
|
class SdTextEncodingStrategy(TextEncodingStrategy):
    """Run tokenized prompts through the text encoder and merge chunked outputs."""

    def __init__(self, clip_skip: Optional[int] = None) -> None:
        """Store ``clip_skip``; ``None`` means the encoder's first output is used as is."""
        self.clip_skip = clip_skip

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Encode ``tokens[0]`` with ``models[0]`` and return ``[hidden_states]``.

        ``tokens[0]`` is indexed along three dimensions (batch, chunk, position).
        It is flattened to rows of ``tokenizer.model_max_length`` ids for the
        encoder and the output is reshaped back to one sequence per batch item.

        When the total token length differs from ``model_max_length`` the
        per-chunk outputs are merged: the first position, then positions
        ``1 .. model_max_length - 2`` of every chunk, then the last position.

        When ``pad_token_id != eos_token_id`` (the ``v2`` case), a chunk whose
        second token is <EOS> has its first kept state replaced by the state
        that follows it.
        """
        text_encoder = models[0]
        tokens = tokens[0]
        tokenizer = tokenize_strategy.tokenizer

        b_size = tokens.size()[0]
        max_token_length = tokens.size()[1] * tokens.size()[2]
        model_max_length = tokenizer.model_max_length
        tokens = tokens.reshape((-1, model_max_length))

        tokens = tokens.to(text_encoder.device)

        if self.clip_skip is None:
            encoder_hidden_states = text_encoder(tokens)[0]
        else:
            enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
            encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
            # An intermediate layer was picked, so the final layer norm is applied by hand.
            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

        # Back to one (concatenated) sequence per batch item.
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if max_token_length != model_max_length:
            is_v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
            inner_length = model_max_length - 2
            # Per-chunk view of the ids: the flattened rows cannot be indexed by
            # batch item directly, which the empty-chunk check below needs.
            chunked_tokens = tokens.reshape((b_size, -1, model_max_length))

            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
            for chunk_index, start in enumerate(range(1, max_token_length, model_max_length)):
                chunk = encoder_hidden_states[:, start : start + inner_length]
                if not is_v1:
                    # Compare against the token *id*: comparing ids with the
                    # `eos_token` string can never match.
                    is_empty = chunked_tokens[:, chunk_index, 1] == tokenizer.eos_token_id
                    is_empty = is_empty.to(chunk.device).unsqueeze(-1)
                    # Built out of place so the encoder output is not mutated.
                    first_state = torch.where(is_empty, chunk[:, 1], chunk[:, 0])
                    chunk = torch.cat([first_state.unsqueeze(1), chunk[:, 1:]], dim=1)
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
            encoder_hidden_states = torch.cat(states_list, dim=1)

        return [encoder_hidden_states]

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens_list: List[torch.Tensor],
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """Encode like ``encode_tokens`` and scale the hidden states by ``weights_list[0]``.

        With a single chunk, every position is multiplied by its weight. With
        several chunks, only the inner positions of each chunk are scaled (the
        first and last weight of every chunk are skipped), written back in place
        into the tensor returned by ``encode_tokens``.
        """
        encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

        weights = weights_list[0].to(encoder_hidden_states.device)

        if weights.shape[1] == 1:
            # Single chunk: weights line up one-to-one with the hidden states.
            encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
        else:
            # Inner length per chunk is derived from the weights instead of a
            # hard-coded 75, so it follows the chunk width actually used.
            inner_length = weights.shape[2] - 2
            for i in range(weights.shape[1]):
                begin = i * inner_length + 1
                end = begin + inner_length
                encoder_hidden_states[:, begin:end] = encoder_hidden_states[:, begin:end] * weights[
                    :, i, 1:-1
                ].unsqueeze(-1)

        return [encoder_hidden_states]
| |
|
| |
|
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    """Latents caching strategy whose cache-file suffix is chosen by the ``sd`` flag."""

    SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
    SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
    SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"

    def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        """Forward the caching options to the base class and pick the suffix for ``sd``."""
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.sd = sd
        if sd:
            self.suffix = SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX
        else:
            self.suffix = SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX

    @property
    def cache_suffix(self) -> str:
        """Suffix selected in ``__init__``."""
        return self.suffix

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Return the cache file path that belongs to ``absolute_path``.

        A file named ``<stem>.npz`` that already exists on disk wins; otherwise
        the path is ``<stem>_<size0>x<size1><suffix>`` with both sizes
        zero-padded to four digits.
        """
        stem, _ = os.path.splitext(absolute_path)
        legacy_path = stem + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
        if os.path.exists(legacy_path):
            return legacy_path
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return stem + size_tag + self.suffix

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        """Delegate to the base-class default check, passing 8 as its first argument."""
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Cache latents for ``image_infos`` through the base-class default routine.

        The encoder callback samples from ``vae.encode(...).latent_dist``. After
        caching, device memory is cleaned unless ``train_util.HIGH_VRAM`` is set.
        """

        def encode_by_vae(img_tensor):
            return vae.encode(img_tensor).latent_dist.sample()

        self._default_cache_batch_latents(
            encode_by_vae, vae.device, vae.dtype, image_infos, flip_aug, alpha_mask, random_crop
        )

        if train_util.HIGH_VRAM:
            return
        train_util.clean_memory_on_device(vae.device)
| |
|