"""SD Tokenizer for text embedding."""
import logging
import os
import traceback
import torch
from transformers import CLIPTokenizerFast

def model_options_long_clip(sd, tokenizer_data, model_options):
    """Handle long CLIP models."""
    return tokenizer_data, model_options

def parse_parentheses(string):
    """Parse nested parentheses into list."""
    result, current, level = [], "", 0
    for char in string:
        if char == "(":
            if level == 0 and current:
                result.append(current)
                current = "("
            elif level == 0:
                current = "("
            else:
                current += char
            level += 1
        elif char == ")":
            level -= 1
            if level == 0:
                result.append(current + ")")
                current = ""
            else:
                current += char
        else:
            current += char
    if current:
        result.append(current)
    return result
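
# Example: parse_parentheses splits a prompt into top-level segments, keeping each
# parenthesized group intact as a single segment.
#   parse_parentheses("a (b (c)) d")  ->  ["a ", "(b (c))", " d"]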

def token_weights(string, weight=1.0):
    """Parse string into tokens with weights."""
    out = []
    for x in parse_parentheses(string):
        w = weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x, w = x[1:-1], weight * 1.1
            if (xx := x.rfind(":")) > 0:
                try:
                    w, x = float(x[xx + 1:]), x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, w)
        else:
            out.append((x, weight))
    return out
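
# Example: token_weights applies the 1.1x boost for "(...)" groups and explicit ":w" weights.
#   token_weights("(masterpiece) (detailed:1.3) sky")
#   ->  [("masterpiece", 1.1), (" ", 1.0), ("detailed", 1.3), (" sky", 1.0)]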

def escape_important(text):
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")

def unescape_important(text):
    return text.replace("\0\1", ")").replace("\0\2", "(")

def expand_directory_list(directories):
    """Expand directories to include subdirectories."""
    dirs = set(directories)
    for x in directories:
        for root, _, _ in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)

def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Load a textual-inversion embedding from one of the given directories."""
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]
    embedding_directory = expand_directory_list(embedding_directory)
    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # reject paths that resolve outside the embedding directory
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except Exception:
            continue
        if os.path.isfile(embed_path):
            valid_file = embed_path
        else:
            # try the common embedding file extensions
            for ext in [".safetensors", ".pt", ".bin"]:
                if os.path.isfile(embed_path + ext):
                    valid_file = embed_path + ext
                    break
        if valid_file:
            break
    if not valid_file:
        return None
    try:
        if valid_file.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(valid_file, device="cpu")
        else:
            embed = torch.load(valid_file, weights_only=True, map_location="cpu")
    except Exception:
        logging.warning(f"{traceback.format_exc()}\n\nerror loading embedding: {embedding_name}")
        return None
    # older textual-inversion checkpoints store the tensor under "string_to_param"
    if "string_to_param" in embed:
        return next(iter(embed["string_to_param"].values()))
    if isinstance(embed, list):
        out_list = [t.reshape(-1, t.shape[-1]) for x in embed for k, t in x.items() if t.shape[-1] == embedding_size]
        return torch.cat(out_list, dim=0) if out_list else None
    if embed_key and embed_key in embed:
        return embed[embed_key]
    return next(iter(embed.values()))
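
# Example (the path and name below are illustrative): load_embed searches the given
# directories and their subdirectories for "my_embed", "my_embed.safetensors",
# "my_embed.pt" or "my_embed.bin" and returns the embedding tensor, or None if not found.
#   vec = load_embed("my_embed", ["./embeddings"], embedding_size=768, embed_key="clip_l")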

class SDTokenizer:
    """Stable Diffusion tokenizer."""

    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True,
                 embedding_directory=None, embedding_size=768, embedding_key="clip_l",
                 tokenizer_class=CLIPTokenizerFast, has_start_token=True,
                 pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path or "include/sd1_tokenizer/")
        self.max_length = max_length
        self.min_length = min_length
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.embedding_directory = embedding_directory
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        # tokenizing the empty string yields only the special tokens, e.g. [start, end] for CLIP
        empty = self.tokenizer("")["input_ids"]
        self.tokens_start = 1 if has_start_token else 0
        self.start_token = empty[0] if has_start_token else None
        self.end_token = empty[1] if has_start_token else empty[0]
        self.inv_vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}

    def _try_get_embedding(self, name):
        embed = load_embed(name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None and (stripped := name.strip(",")) != name:
            embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
            return embed, name[len(stripped):]
        return embed, ""

    def tokenize_with_weights(self, text, return_word_ids=False):
        pad_token = self.end_token if self.pad_with_end else 0
        parsed = token_weights(escape_important(text), 1.0)

        # tokenize each weighted segment word by word; embeddings become raw vectors instead of ids
        tokens = []
        for segment, weight in parsed:
            for word in unescape_important(segment).replace("\n", " ").split():
                if word.startswith(self.embedding_identifier) and self.embedding_directory:
                    name = word[len(self.embedding_identifier):].strip("\n")
                    embed, leftover = self._try_get_embedding(name)
                    if embed is None:
                        logging.warning(f"embedding:{name} does not exist")
                    else:
                        logging.debug(f"loading embedding:{name}")
                        tokens.append([(embed[i], weight) for i in range(embed.shape[0])] if len(embed.shape) > 1 else [(embed, weight)])
                    # if text was glued to the embedding name, keep parsing it; otherwise move on
                    if leftover:
                        word = leftover
                    else:
                        continue
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])

        # reshape the token groups into batches of at most max_length entries
        batched = []
        batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
        batched.append(batch)
        for i, t_group in enumerate(tokens):
            # long words may be split across batches, shorter ones are kept whole
            is_large = len(t_group) >= self.max_word_length
            while t_group:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining = self.max_length - len(batch) - 1
                    if is_large:
                        # split the word: fill the current batch and close it with an end token
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining:]
                    else:
                        # the word does not fit: close and pad this batch, retry it in the next one
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * remaining)
                    # start a new batch in either case
                    batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
                    batched.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # finish the last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        return batched if return_word_ids else [[(t, w) for t, w, _ in x] for x in batched]

    def untokenize(self, pairs):
        return [(a, self.inv_vocab[a[0]]) for a in pairs]

class SD1Tokenizer:
    """SD1 Tokenizer wrapper."""

    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        self.clip = f"clip_{clip_name}"
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text, return_word_ids=False):
        return {self.clip_name: getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)}

    def untokenize(self, pairs):
        return getattr(self, self.clip).untokenize(pairs)
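
# Minimal usage sketch. It assumes the tokenizer files referenced by SDTokenizer's default
# path ("include/sd1_tokenizer/") are present and that "./embeddings" is a local directory
# of textual-inversion files; both paths are illustrative assumptions of this sketch.
if __name__ == "__main__":
    tokenizer = SD1Tokenizer(embedding_directory="./embeddings")
    weighted = tokenizer.tokenize_with_weights("a (photo:1.2) of a cat")
    # weighted["l"] is a list of 77-entry batches of (token_id_or_embedding, weight) pairs
    print(len(weighted["l"]), len(weighted["l"][0]))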