import torch
import math
import re
import os
from .block_lora import lbw_lora
from typing import Union, Optional, List, Tuple
from abc import ABC


class BaseTextualInversionManager(ABC):
    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: List[int]) -> List[int]:
        raise NotImplementedError()


class DiffusersTextualInversionManager(BaseTextualInversionManager):
    """
    A textual inversion manager for use with diffusers.
    """

    def __init__(self, pipe):
        self.pipe = pipe

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: List[int]) -> List[int]:
        if len(token_ids) == 0:
            return token_ids

        prompt = self.pipe.tokenizer.decode(token_ids)
        prompt = self.pipe.maybe_convert_prompt(prompt, self.pipe.tokenizer)
        return self.pipe.tokenizer.encode(prompt, add_special_tokens=False)


re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\(){}\[\]:]+|
:
""", re.X)

re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs:
    text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \\( - literal character '('
      \\[ - literal character '['
      \\) - literal character ')'
      \\] - literal character ']'
      \\\\ - literal character '\\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0], ['house', 1.5730000000000004], [' ', 1.1], ['on', 1.0], [' a ', 1.1], ['hill', 0.55], [', sun, ', 1.1], ['sky', 1.4641000000000006], ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
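
# For reference, the BREAK keyword is emitted as its own segment with the
# sentinel weight -1 (a sketch of the expected output of the split logic
# above, not an authoritative doctest):
#
#   parse_prompt_attention('foreground BREAK background')
#   # -> [['foreground', 1.0], ['BREAK', -1], ['background', 1.0]]
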
class CLIPMultiTextCustomEmbedder(object):
    def __init__(self, tokenizers, text_encoders, device,
                 textual_inversion_manager: BaseTextualInversionManager,
                 clip_stop_at_last_layers=1, requires_pooled=False):
        self.tokenizer = tokenizers[0]
        self.tokenizer_2 = tokenizers[1]
        self.text_encoder = text_encoders[0]
        self.text_encoder_2 = text_encoders[1]
        self.textual_inversion_manager = textual_inversion_manager
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers
        self.requires_pooled = requires_pooled

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, multipliers = cache[line]
            else:
                remade_tokens, _, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        pooled = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1, pooly = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)
            pooled = pooly if pooled is None else torch.cat((pooled, pooly), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return [z, pooled]
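    # Rough shape sketch for __call__, assuming an SDXL pipe (per-encoder
    # hidden sizes 768 and 1280, concatenated to 2048 by process_tokens
    # below): a prompt that tokenizes to 120 tokens is padded to two
    # 75-token chunks, each encoded as (batch, 77, 2048) once bos/eos are
    # added, then concatenated along the sequence axis, so z has shape
    # (batch, 154, 2048).
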
    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)

        encoders = [self.text_encoder, self.text_encoder_2]
        plist = []
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        for text_encoder in encoders:
            output = text_encoder(input_ids=tokens, output_hidden_states=True)
            pooled = output[0]
            if self.clip_stop_at_last_layers > 1:
                z = output.hidden_states[-(2 + self.clip_stop_at_last_layers)]
            else:
                z = output.hidden_states[-2]
            z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
            plist.append(z)

        z = torch.concat(plist, dim=-1)
        return z, pooled

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
               [[1.0] + batch_multipliers[0]]

    def get_token_ids(self, texts: List[str]) -> List[List[int]]:
        token_ids_list = self.tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            return_tensors=None,  # just give me lists of ints
        )['input_ids']

        result = []
        for token_ids in token_ids_list:
            # trim bos/eos
            token_ids = token_ids[1:-1]

            # pad for textual inversions with vector length > 1
            if self.textual_inversion_manager is not None:
                token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)

            token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
            result.append(token_ids)

        return result

    def get_pooled_embeddings(self, texts: List[str],
                              attention_mask: Optional[torch.Tensor] = None,
                              device: Optional[str] = None) -> Optional[torch.Tensor]:
        device = device or self.device
        token_ids = self.get_token_ids(texts)
        token_ids = torch.tensor(token_ids, dtype=torch.long).to(device)
        text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)
        pooled = text_encoder_output.pooler_output
        return pooled
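
# A minimal construction sketch for the SDXL path, assuming a loaded
# StableDiffusionXLPipeline bound to `pipe` (the variable name is
# illustrative; text_embeddings() below performs this wiring for you):
#
#   embedder = CLIPMultiTextCustomEmbedder(
#       tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
#       text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
#       device=pipe.text_encoder.device,
#       textual_inversion_manager=DiffusersTextualInversionManager(pipe),
#       requires_pooled=True)
#   z, pooled = embedder("a (red:1.2) bicycle")
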
class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 textual_inversion_manager: BaseTextualInversionManager,
                 clip_stop_at_last_layers=1, requires_pooled=False):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.textual_inversion_manager = textual_inversion_manager
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers
        self.requires_pooled = requires_pooled

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)

        outputs = self.text_encoder(input_ids=tokens, output_hidden_states=True)
        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)

        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
               [[1.0] + batch_multipliers[0]]

    def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]:
        token_ids_list = self.tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            return_tensors=None,  # just give me lists of ints
        )['input_ids']

        result = []
        for token_ids in token_ids_list:
            # trim bos/eos
            token_ids = token_ids[1:-1]

            # pad for textual inversions with vector length > 1
            if self.textual_inversion_manager is not None:
                token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)

            # add back bos/eos if requested
            if include_start_and_end_markers:
                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]

            result.append(token_ids)

        return result

    def get_pooled_embeddings(self, texts: List[str],
                              attention_mask: Optional[torch.Tensor] = None,
                              device: Optional[str] = None) -> Optional[torch.Tensor]:
        device = device or self.device
        token_ids = self.get_token_ids(texts)
        token_ids = torch.tensor(token_ids, dtype=torch.long).to(device)
        text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)
        pooled = text_encoder_output.pooler_output
        return pooled
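
# Clip-skip sketch for CLIPTextCustomEmbedder, assuming an SD 1.x pipe bound
# to `pipe` (illustrative only; text_embeddings() below does this for you):
# clip_stop_at_last_layers=1 uses last_hidden_state, while a value of 2
# re-applies final_layer_norm to the second-to-last hidden state, i.e. the
# usual "clip skip 2" behaviour:
#
#   embedder = CLIPTextCustomEmbedder(
#       tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
#       device=pipe.text_encoder.device,
#       textual_inversion_manager=DiffusersTextualInversionManager(pipe),
#       clip_stop_at_last_layers=2)
#   z = embedder("masterpiece, best quality")  # (1, 77, 768) for one chunk
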
def text_embeddings_equal_len(text_embedder, prompt, negative_prompt) -> Tuple[List[torch.Tensor], Optional[List[torch.Tensor]]]:
    conds = text_embedder(prompt)
    unconds = text_embedder(negative_prompt)

    if text_embedder.requires_pooled:
        conditionings = [conds[0], unconds[0]]
        pooled = [text_embedder.get_pooled_embeddings([prompt], device=text_embedder.device),
                  text_embedder.get_pooled_embeddings([negative_prompt], device=text_embedder.device)]
    else:
        conditionings = [conds, unconds]
        pooled = None

    emptystring_conditioning = text_embedder("")
    if isinstance(emptystring_conditioning, (tuple, list)):
        # discard pooled
        emptystring_conditioning = emptystring_conditioning[0]

    # ensure all conditioning tensors share batch size and embedding width
    c0_shape = conditionings[0].shape
    if not all([c.shape[0] == c0_shape[0] and c.shape[2] == c0_shape[2] for c in conditionings]):
        raise ValueError(f"All conditioning tensors must have the same batch size ({c0_shape[0]}) and number of embeddings per token ({c0_shape[2]})")

    if len(emptystring_conditioning.shape) == 2:
        emptystring_conditioning = emptystring_conditioning.unsqueeze(0)
    empty_z = torch.cat([emptystring_conditioning] * c0_shape[0])

    # if necessary, pad shorter tensors out with an emptystring tensor
    max_token_count = max([c.shape[1] for c in conditionings])
    for i, c in enumerate(conditionings):
        while c.shape[1] < max_token_count:
            c = torch.cat([c, empty_z], dim=1)
            conditionings[i] = c

    return conditionings, pooled


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1, pool=False):
    if pool:
        text_embedder = CLIPMultiTextCustomEmbedder(
            tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
            device=pipe.text_encoder.device,
            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
            clip_stop_at_last_layers=clip_stop_at_last_layers,
            requires_pooled=pool)
    else:
        text_embedder = CLIPTextCustomEmbedder(
            tokenizer=pipe.tokenizer,
            text_encoder=pipe.text_encoder,
            device=pipe.text_encoder.device,
            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
            clip_stop_at_last_layers=clip_stop_at_last_layers,
            requires_pooled=pool)
    conds, pooled = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return conds, pooled


def apply_embeddings(pipe, input_str, epath, pool):
    for name, path in epath.items():
        new_string = input_str.replace(name, "")
        if new_string != input_str:  # the embedding token appears in the prompt
            if pool:
                # SDXL embeddings ship separate tensors for both text encoders
                from safetensors.torch import load_file
                state = load_file(path)
                pipe.load_textual_inversion(pretrained_model_name_or_path=state["clip_g"],
                                            token=name,
                                            text_encoder=pipe.text_encoder_2,
                                            tokenizer=pipe.tokenizer_2,
                                            local_files_only=True)
                pipe.load_textual_inversion(pretrained_model_name_or_path=state["clip_l"],
                                            token=name,
                                            text_encoder=pipe.text_encoder,
                                            tokenizer=pipe.tokenizer,
                                            local_files_only=True)
            else:
                pipe.load_textual_inversion(pretrained_model_name_or_path=path,
                                            token=name,
                                            local_files_only=True)
    return input_str


def bpro(prompt):
    # Approximate token counts by word counts so that each BREAK can be padded
    # toward a 75-token chunk boundary before encoding.
    k = prompt.split(",")
    thu = []
    for g in k:
        f = g.count(" ")
        thu.append([g, f + 1])
    off = 0
    nl = []
    t = 0
    for x in thu:
        if "BREAK" in x[0]:
            tok = t + off
            add = tok % 75
            nl += [" "] * add
            off += add
            continue
        t += x[1]
        nl.append(x[0])
    return ",".join(nl)


def lora_prompt(prompt, pipe, lpath):
    loras = []
    adap_list = []
    alphas = []
    prompt = bpro(prompt)

    def network_replacement(m):
        alias = m.group(1)
        num = m.group(2)
        try:
            data = lpath[alias]
            mpath = data[1]
            dpath = data[0]
            alias = data[2]
        except KeyError:
            return ""
        if "|" in num:
            # block-weighted syntax: "<alpha>|<block-weight spec>"
            t = num.split("|")
            alpha = float(t[0])
            apply = t[1]
            npath = f"{mpath}{alias}_{apply}.safetensors"
            try:
                data = lpath[f"{alias}_{apply}"]
                loras.append([data[0], alpha])
                return alias
            except KeyError:
                lpath[f"{alias}_{apply}"] = [npath, dpath]
                lbw_lora(dpath, npath, apply)
                dpath = npath
        else:
            alpha = float(num)
        loras.append([dpath, alpha])
        return alias

    # NOTE: the original pattern was lost; this assumes the conventional
    # <lora:alias:weight> prompt syntax implied by network_replacement above.
    re_lora = re.compile(r"<lora:([^:>]+):([^>]+)>")
    prompt = re.sub(re_lora, network_replacement, prompt)

    for k in loras:
        p = os.path.abspath(os.path.join(k[0], ".."))
        safe = os.path.basename(k[0])
        name = os.path.splitext(safe)[0].replace(".", "_")
        alphas.append(k[1])
        adap_list.append(name)
        try:
            pipe.load_lora_weights(p, weight_name=safe, adapter_name=name)
        except Exception:
            # swallow load errors (e.g. the adapter was already loaded)
            pass
    pipe.set_adapters(adap_list, adapter_weights=alphas)
    return prompt, lpath
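
# Prompt syntax consumed by lora_prompt, under the <lora:alias:weight>
# assumption noted above ("MyLora" and "OUTD" are illustrative names only):
#
#   "a portrait <lora:MyLora:0.8>"       -> loads MyLora at weight 0.8
#   "a portrait <lora:MyLora:0.8|OUTD>"  -> first rewrites the weights through
#       lbw_lora() with block-weight spec "OUTD", caches the rewritten copy in
#       lpath, then loads it at weight 0.8
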
def create_conditioning(pipe, positive: str, negative: str, epath, lora_list, clip_skip=1, pool=False):
    positive = apply_embeddings(pipe, positive, epath, pool)
    negative = apply_embeddings(pipe, negative, epath, pool)
    positive, lora_list = lora_prompt(positive, pipe, lora_list)
    conds, pooled = text_embeddings(pipe, positive, negative, clip_skip, pool)
    if pool:
        embed_dict = {
            'prompt_embeds': conds[0],
            'pooled_prompt_embeds': pooled[0],
            'negative_prompt_embeds': conds[1],
            'negative_pooled_prompt_embeds': pooled[1],
        }
    else:
        embed_dict = {
            'prompt_embeds': conds[0],
            'negative_prompt_embeds': conds[1],
        }
    return (embed_dict, lora_list)
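
# End-to-end usage sketch (illustrative only: the model id, embedding dict and
# LoRA dict below are placeholders, and a LoRA entry must exist in lora_list
# for a <lora:...> tag to resolve):
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   embed_dict, lora_list = create_conditioning(
#       pipe,
#       positive="a (red:1.2) bicycle",
#       negative="lowres, blurry",
#       epath={}, lora_list={}, clip_skip=2)
#   image = pipe(**embed_dict, num_inference_steps=30).images[0]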