"""SD Tokenizer for text embedding."""
import logging
import os
import traceback
import torch
from transformers import CLIPTokenizerFast


def model_options_long_clip(sd, tokenizer_data, model_options):
    """Handle long CLIP models."""
    return tokenizer_data, model_options


def parse_parentheses(string):
    """Parse nested parentheses into list."""
    result, current, level = [], "", 0
    for char in string:
        if char == "(":
            if level == 0 and current:
                result.append(current)
                current = "("
            elif level == 0:
                current = "("
            else:
                current += char
            level += 1
        elif char == ")":
            level -= 1
            if level == 0:
                result.append(current + ")")
                current = ""
            else:
                current += char
        else:
            current += char
    if current:
        result.append(current)
    return result
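
# Illustrative example (not part of the original source): parse_parentheses splits
# only at the top level and keeps nested groups intact, e.g.
#   parse_parentheses("a (b (c)) d")  ->  ["a ", "(b (c))", " d"]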


def token_weights(string, weight=1.0):
    """Parse string into tokens with weights."""
    out = []
    for x in parse_parentheses(string):
        w = weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x, w = x[1:-1], weight * 1.1
            if (xx := x.rfind(":")) > 0:
                try:
                    w, x = float(x[xx + 1:]), x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, w)
        else:
            out.append((x, weight))
    return out
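
# Illustrative example (not part of the original source): parentheses boost the
# weight by 1.1 per level unless an explicit ":weight" suffix overrides it, e.g.
#   token_weights("a (b:1.5) (c)", 1.0)
#   -> [("a ", 1.0), ("b", 1.5), (" ", 1.0), ("c", 1.1)]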


def escape_important(text):
    """Replace escaped parentheses with sentinel characters so they survive weight parsing."""
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")


def unescape_important(text):
    """Restore parentheses that were protected by escape_important."""
    return text.replace("\0\1", ")").replace("\0\2", "(")
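
# Illustrative example (not part of the original source): escaped parentheses are
# swapped for sentinels before weight parsing so token_weights does not treat them
# as weight groups, then restored (without the backslashes) afterwards:
#   unescape_important(escape_important("a \\(literal\\) paren"))
#   -> "a (literal) paren"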


def expand_directory_list(directories):
    """Expand directories to include subdirectories."""
    dirs = set(directories)
    for x in directories:
        for root, _, _ in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)
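
# Illustrative example (hypothetical layout, not from the original source): with
# directories embeddings/ and embeddings/styles/ on disk,
# expand_directory_list(["embeddings"]) returns both paths, so embeddings nested
# in subfolders are also searched.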


def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Load embedding from directory."""
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]
    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except Exception:
            continue
        
        if os.path.isfile(embed_path):
            valid_file = embed_path
        else:
            for ext in [".safetensors", ".pt", ".bin"]:
                if os.path.isfile(embed_path + ext):
                    valid_file = embed_path + ext
                    break
        if valid_file:
            break

    if not valid_file:
        return None

    try:
        if valid_file.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(valid_file, device="cpu")
        else:
            embed = torch.load(valid_file, weights_only=True, map_location="cpu")
    except Exception:
        logging.warning(f"{traceback.format_exc()}\n\nerror loading embedding: {embedding_name}")
        return None

    if "string_to_param" in embed:
        return next(iter(embed["string_to_param"].values()))
    if isinstance(embed, list):
        out_list = [t.reshape(-1, t.shape[-1]) for x in embed for k, t in x.items() if t.shape[-1] == embedding_size]
        return torch.cat(out_list, dim=0) if out_list else None
    if embed_key and embed_key in embed:
        return embed[embed_key]
    return next(iter(embed.values()))
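
# Illustrative usage sketch (file name and directory here are hypothetical, not
# from the original source):
#   vec = load_embed("my_style.safetensors", ["models/embeddings"], 768, "clip_l")
#   # vec is typically a tensor whose last dimension matches embedding_size,
#   # or None if the file was not found or could not be loaded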


class SDTokenizer:
    """Stable Diffusion tokenizer."""
    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True,
                 embedding_directory=None, embedding_size=768, embedding_key="clip_l",
                 tokenizer_class=CLIPTokenizerFast, has_start_token=True,
                 pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path or "include/sd1_tokenizer/")
        self.max_length = max_length
        self.min_length = min_length
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.embedding_directory = embedding_directory
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"

        empty = self.tokenizer("")["input_ids"]
        self.tokens_start = 1 if has_start_token else 0
        self.start_token = empty[0] if has_start_token else None
        self.end_token = empty[1] if has_start_token else empty[0]
        self.inv_vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}

    def _try_get_embedding(self, name):
        """Look up an embedding by name, retrying with trailing commas stripped."""
        embed = load_embed(name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None and (stripped := name.rstrip(",")) != name:
            embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
            return embed, name[len(stripped):]
        return embed, ""

    def tokenize_with_weights(self, text, return_word_ids=False):
        """Tokenize a weighted prompt into padded batches of (token, weight[, word_id]) tuples."""
        pad_token = self.end_token if self.pad_with_end else 0
        parsed = token_weights(escape_important(text), 1.0)
        tokens = []

        for segment, weight in parsed:
            for word in unescape_important(segment).replace("\n", " ").split():
                if word.startswith(self.embedding_identifier) and self.embedding_directory:
                    name = word[len(self.embedding_identifier):].strip("\n")
                    embed, leftover = self._try_get_embedding(name)
                    if embed is None:
                        logging.warning(f"embedding:{name} does not exist")
                    else:
                        tokens.append([(embed[i], weight) for i in range(embed.shape[0])] if len(embed.shape) > 1 else [(embed, weight)])
                        print("loading", name)
                    if leftover:
                        word = leftover
                    else:
                        continue
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])

        batched = []
        batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
        batched.append(batch)

        for i, t_group in enumerate(tokens):
            is_large = len(t_group) >= self.max_word_length
            while t_group:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining = self.max_length - len(batch) - 1
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining:]
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * remaining)
                    batch = [(self.start_token, 1.0, 0)] if self.start_token is not None else []
                    batched.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        return batched if return_word_ids else [[(t, w) for t, w, _ in x] for x in batched]
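
    # Illustrative example (not part of the original source): with the default
    # CLIP-L settings, tokenize_with_weights("a (cat:1.2)") returns a list of
    # batches, each a length-77 list of (token_id, weight) pairs padded with the
    # end/pad token; with return_word_ids=True each entry is
    # (token_id, weight, word_index), where special and padding tokens use index 0.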

    def untokenize(self, pairs):
        """Return each (token_id, weight) pair together with its vocabulary string."""
        return [(a, self.inv_vocab[a[0]]) for a in pairs]


class SD1Tokenizer:
    """SD1 Tokenizer wrapper."""
    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        self.clip = f"clip_{clip_name}"
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text, return_word_ids=False):
        return {self.clip_name: getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)}

    def untokenize(self, pairs):
        return getattr(self, self.clip).untokenize(pairs)
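

# Minimal usage sketch (assumes the default tokenizer files exist at
# "include/sd1_tokenizer/" or that a local CLIP tokenizer path is passed in;
# the embedding directory below is hypothetical):
#   tok = SD1Tokenizer(embedding_directory="models/embeddings")
#   batches = tok.tokenize_with_weights("a photo of a (cat:1.2)")
#   # batches["l"] is a list of length-77 lists of (token_id, weight) pairs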