| """SD Tokenizer for text embedding.""" | |
| import logging | |
| import os | |
| import traceback | |
| import torch | |
| from transformers import CLIPTokenizerFast | |


def model_options_long_clip(sd, tokenizer_data, model_options):
    """Handle long CLIP models."""
    return tokenizer_data, model_options


def parse_parentheses(string):
    """Split a string into top-level text runs and "(...)" groups."""
    result, current, level = [], "", 0
    for char in string:
        if char == "(":
            if level == 0 and current:
                # Close the pending plain-text item before opening a group.
                result.append(current)
                current = "("
            elif level == 0:
                current = "("
            else:
                current += char
            level += 1
        elif char == ")":
            level -= 1
            if level == 0:
                result.append(current + ")")
                current = ""
            else:
                current += char
        else:
            current += char
    if current:
        result.append(current)
    return result
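
# Worked example (illustrative, not executed):
#   parse_parentheses("a (b (c)) d")  ->  ["a ", "(b (c))", " d"]
# Top-level text runs become separate items, while a parenthesized group
# keeps its nested parentheses intact for token_weights() to unwrap one
# level at a time.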


def token_weights(string, weight=1.0):
    """Parse prompt text into (text, weight) pairs; each "()" level multiplies the weight by 1.1."""
    out = []
    for x in parse_parentheses(string):
        w = weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            # Unwrap one parenthesis level; an explicit ":1.3" suffix
            # overrides the default 1.1 multiplier.
            x, w = x[1:-1], weight * 1.1
            if (xx := x.rfind(":")) > 0:
                try:
                    w, x = float(x[xx + 1:]), x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, w)
        else:
            out.append((x, weight))
    return out
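
# Illustrative examples of the weighting syntax (the 1.1 factor per
# parenthesis level matches the multiplier above):
#   token_weights("an (important:1.3) word")
#     -> [("an ", 1.0), ("important", 1.3), (" word", 1.0)]
#   token_weights("((nested))")
#     -> [("nested", 1.2100000000000002)]   # 1.1 * 1.1, no explicit weight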


def escape_important(text):
    """Replace escaped parentheses with sentinel bytes so they survive parsing."""
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")


def unescape_important(text):
    """Restore literal parentheses from their sentinel bytes."""
    return text.replace("\0\1", ")").replace("\0\2", "(")
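
# Example: escaping keeps literal parentheses out of the weighting syntax.
# "\(" and "\)" become the sentinel bytes "\0\2"/"\0\1" before parsing and
# are restored afterwards:
#   unescape_important(escape_important("a \\(literal\\) paren"))
#     -> "a (literal) paren"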


def expand_directory_list(directories):
    """Expand directories to include subdirectories."""
    dirs = set(directories)
    for x in directories:
        for root, _, _ in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)


def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Load a textual-inversion embedding from the given directories."""
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]
    embedding_directory = expand_directory_list(embedding_directory)
    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # Reject names that resolve outside the embedding directory.
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except Exception:
            continue
        if os.path.isfile(embed_path):
            valid_file = embed_path
        else:
            for ext in [".safetensors", ".pt", ".bin"]:
                if os.path.isfile(embed_path + ext):
                    valid_file = embed_path + ext
                    break
        if valid_file:
            break
    if not valid_file:
        return None
    try:
        if valid_file.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(valid_file, device="cpu")
        else:
            embed = torch.load(valid_file, weights_only=True, map_location="cpu")
    except Exception:
        logging.warning(f"{traceback.format_exc()}\n\nerror loading embedding: {embedding_name}")
        return None
    # Support the common checkpoint layouts: "string_to_param" dicts,
    # lists of per-layer dicts, and flat {key: tensor} dicts.
    if "string_to_param" in embed:
        return next(iter(embed["string_to_param"].values()))
    if isinstance(embed, list):
        out_list = [t.reshape(-1, t.shape[-1]) for x in embed for k, t in x.items() if t.shape[-1] == embedding_size]
        return torch.cat(out_list, dim=0) if out_list else None
    if embed_key and embed_key in embed:
        return embed[embed_key]
    return next(iter(embed.values()))
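
# Minimal usage sketch (the path below is hypothetical; any directory of
# .safetensors/.pt/.bin textual-inversion files works):
#   vec = load_embed("my_style", "/path/to/embeddings", 768, "clip_l")
#   # vec is typically a tensor of shape [num_vectors, 768], or None if
#   # the file is missing or unreadable.
# The commonpath() check above also rejects names like "../../etc/x",
# confining lookups to the configured embedding directories.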


class SDTokenizer:
    """Stable Diffusion tokenizer with prompt weighting and embedding support."""

    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True,
                 embedding_directory=None, embedding_size=768, embedding_key="clip_l",
                 tokenizer_class=CLIPTokenizerFast, has_start_token=True,
                 pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path or "include/sd1_tokenizer/")
        self.max_length = max_length
        self.min_length = min_length
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.embedding_directory = embedding_directory
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        # Tokenizing the empty string yields only the special tokens,
        # e.g. [start, end] for CLIP, from which both ids are recovered.
        empty = self.tokenizer("")["input_ids"]
        self.tokens_start = 1 if has_start_token else 0
        self.start_token = empty[0] if has_start_token else None
        self.end_token = empty[1] if has_start_token else empty[0]
        self.inv_vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}

    def _try_get_embedding(self, name):
        """Look up an embedding, retrying without trailing commas; returns (embed, leftover)."""
        embed = load_embed(name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None and (stripped := name.strip(",")) != name:
            embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
            return embed, name[len(stripped):]
        return embed, ""
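
    # Example: a prompt token like "embedding:my_style," (comma glued on by
    # the prompt) fails the exact lookup, succeeds after strip(","), and the
    # trailing "," is returned as leftover text for the caller to re-parse.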

    def tokenize_with_weights(self, text, return_word_ids=False):
        """Tokenize text into weighted batches of at most max_length tokens."""
        pad_token = self.end_token if self.pad_with_end else 0
        parsed = token_weights(escape_important(text), 1.0)
        tokens = []
        for segment, weight in parsed:
            for word in unescape_important(segment).replace("\n", " ").split():
                if word.startswith(self.embedding_identifier) and self.embedding_directory:
                    name = word[len(self.embedding_identifier):].strip("\n")
                    embed, leftover = self._try_get_embedding(name)
                    if embed is None:
                        logging.warning(f"embedding:{name} does not exist")
                    else:
                        # An embedding contributes raw vectors in place of token ids.
                        tokens.append([(embed[i], weight) for i in range(embed.shape[0])] if len(embed.shape) > 1 else [(embed, weight)])
                        logging.debug(f"loading embedding: {name}")
                    if leftover:
                        word = leftover
                    else:
                        continue
                # Drop the special tokens here; they are re-added per batch below.
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        # Reshape the flat token groups into CLIP-sized batches.
        batched = []
        batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
        batched.append(batch)
        for i, t_group in enumerate(tokens):
            # Long words may be split across batches; short ones are kept whole.
            is_large = len(t_group) >= self.max_word_length
            while t_group:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining = self.max_length - len(batch) - 1
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining:]
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * remaining)
                    batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
                    batched.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []
        # Close and pad the final batch.
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
        return batched if return_word_ids else [[(t, w) for t, w, _ in x] for x in batched]
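
    # Shape of the result (illustrative): one batch per max_length window,
    # each padded to exactly max_length (token, weight) pairs when
    # pad_to_max_length is set, e.g. [[(start_id, 1.0), (tok, w), ...,
    # (end_id, 1.0), (pad_id, 1.0), ...]] with 77 entries for SD1.
    # With return_word_ids=True each pair gains a third element: the 1-based
    # index of the source word, 0 marking start/end/padding tokens.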

    def untokenize(self, pairs):
        """Map (token, weight) pairs back to their vocabulary strings."""
        return [(a, self.inv_vocab[a[0]]) for a in pairs]


class SD1Tokenizer:
    """SD1 Tokenizer wrapper."""

    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        self.clip = f"clip_{clip_name}"
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text, return_word_ids=False):
        return {self.clip_name: getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)}

    def untokenize(self, pairs):
        return getattr(self, self.clip).untokenize(pairs)
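

# Minimal smoke test: a sketch assuming network access to the Hugging Face
# Hub. "openai/clip-vit-large-patch14" stands in for the bundled
# include/sd1_tokenizer/ files, which may be absent in a fresh checkout.
if __name__ == "__main__":
    from functools import partial

    tok = SD1Tokenizer(tokenizer=partial(SDTokenizer, tokenizer_path="openai/clip-vit-large-patch14"))
    out = tok.tokenize_with_weights("a (very:1.3) fluffy cat")
    batches = out["l"]
    # Expect one batch of 77 (token, weight) pairs for a short prompt.
    print(f"{len(batches)} batch(es) of {len(batches[0])} (token, weight) pairs")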