# dikdimon's picture
# Upload extensions using SD-Hub extension
# 3dabe4a verified
# raw
# history blame
# 7.65 kB
from dataclasses import dataclass
from itertools import product
import re
from typing import Union, List, Tuple
import numpy as np
import open_clip
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase as CLIP
# `sgm` is optional: when it cannot be imported, fall back to a placeholder
# type so the annotations that mention CLIP_SDXL still resolve.
try:
    from sgm.modules import GeneralConditioner as CLIP_SDXL
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit raised during import are not swallowed,
    # while any ordinary import-time failure still degrades gracefully.
    print("[Cutoff] failed to load `sgm.modules.GeneralConditioner`")
    CLIP_SDXL = int
from modules import prompt_parser, shared
from scripts.cutofflib.utils import log
class ClipWrapper:
    """Uniform token <-> id lookup on top of a (hijacked) text encoder.

    Two encoder layouts are recognised:

    * an object that is itself the embedder (``id_start``, ``id_end`` and
      ``hijack`` are read from it directly);
    * an object exposing ``embedders`` (flagged by ``self.sdxl``), in which
      case those attributes are read from ``embedders[0]``.

    Attributes:
        te: the encoder object passed to the constructor.
        sdxl: True when ``te`` has an ``embedders`` attribute.
        v1: True when the wrapped model exposes a ``tokenizer`` attribute;
            lookups then go through that tokenizer's
            ``_convert_token_to_id`` / ``convert_ids_to_tokens`` methods.
            Otherwise the module-level ``open_clip`` tokenizer's
            ``encoder`` / ``decoder`` tables are used.
        t: the tokenizer selected according to ``v1``.
    """

    def __init__(self, te: 'Union[CLIP, CLIP_SDXL]'):
        self.te = te
        # Decide the layout first: everything below depends on it.
        self.sdxl = hasattr(te, 'embedders')

        # Look the wrapped model up defensively.  Reading `te.wrapped`
        # unconditionally raises AttributeError for an object that only
        # exposes `embedders`, which made the SDXL branches unreachable.
        wrapped = getattr(te, 'wrapped', None)
        if wrapped is None and self.sdxl:
            # Same embedder that id_start / id_end / hijack are taken from.
            wrapped = getattr(te.embedders[0], 'wrapped', None)

        self.v1 = hasattr(wrapped, 'tokenizer')
        if self.v1:
            self.t = wrapped.tokenizer
        else:
            self.t = open_clip.tokenizer._tokenizer

    @property
    def _embedder(self):
        """Object that carries ``id_start``, ``id_end`` and ``hijack``."""
        return self.te.embedders[0] if self.sdxl else self.te

    def token_to_id(self, token: str) -> int:
        """Return the vocabulary id of ``token``."""
        if self.v1:
            return self.t._convert_token_to_id(token)  # type: ignore
        # Plain dict lookup: raises KeyError for an unknown token.
        return self.t.encoder[token]

    def id_to_token(self, id: int) -> str:
        """Return the token text stored under vocabulary id ``id``."""
        if self.v1:
            return self.t.convert_ids_to_tokens(id)  # type: ignore
        return self.t.decoder[id]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Return the token texts for every id in ``ids``, in order."""
        if self.v1:
            return self.t.convert_ids_to_tokens(ids)  # type: ignore
        decoder = self.t.decoder
        return [decoder[token_id] for token_id in ids]

    def token(self, token: Union[int, str]):
        """Build a ``Token`` from either an id (int) or a token text (str)."""
        if isinstance(token, int):
            return Token(token, self.id_to_token(token))
        return Token(self.token_to_id(token), token)

    @property
    def id_start(self):
        """Id of the start-of-text marker, as reported by the embedder."""
        return self._embedder.id_start

    @property
    def id_end(self):
        """Id of the end-of-text marker, as reported by the embedder."""
        return self._embedder.id_end

    @property
    def hijack(self):
        """The ``hijack`` object attached to the embedder."""
        return self._embedder.hijack
@dataclass
class Token:
    """One vocabulary entry: a numeric id together with its text form."""

    # Integer id of the entry in the tokenizer vocabulary.
    id: int
    # Token text as the tokenizer stores it (elsewhere in this module such
    # texts carry a `</w>` suffix, e.g. `,</w>`).
    token: str
class CutoffPrompt:
    """A prompt split on commas, with a masked variant of every block.

    For each comma-separated block two texts are kept: the original one
    (``base``) and one in which every target word is replaced by padding
    (``cut``).  A boolean switch per block (``sw``) selects which of the two
    is used when the prompt is re-assembled by ``text()``.
    """

    @staticmethod
    def _cutoff(prompt: str, clip: 'CLIP', tokens: List[str], padding: str):
        """Return ``(original, masked)`` text pairs, one per comma block.

        Args:
            prompt: full prompt text; it is split on ``,``.
            clip: text encoder used to count how many tokens a target takes.
            tokens: target words to mask (matched on word boundaries).
            padding: text inserted once per token of the replaced target.
        """

        def token_count(text: str) -> int:
            te = ClipWrapper(clip)
            tt = token_to_block(clip, text)
            # tt[0] == te.id_start (<|startoftext|>)
            for index, (t, _) in enumerate(tt):
                if t.id == te.id_end:  # <|endoftext|>
                    # Entries strictly between the start marker (index 0)
                    # and the first end marker are the target's tokens.
                    return index - 1
            return 0  # must not happen...

        # One (pattern, padding text) pair per target.  The padding run has
        # as many entries as the target has tokens.
        substitutions = [
            (
                re.compile(r'\b' + re.escape(target) + r'\b'),
                ' ' + ' '.join([padding] * token_count(target)) + ' ',
            )
            for target in tokens
        ]

        rows: List[Tuple[str, str]] = []
        for block in prompt.split(','):
            masked = block
            for pattern, filler in substitutions:
                # A callable replacement is inserted verbatim.  Passing the
                # string itself would have it parsed as a replacement
                # template, so a padding token containing a backslash would
                # raise re.error or be rewritten.
                masked = pattern.sub(lambda _match, filler=filler: filler, masked)
            rows.append((block, masked))
        return rows

    def __init__(self, prompt: str, clip: 'CLIP', tokens: List[str], padding: str):
        """Split ``prompt`` and precompute the masked form of each block."""
        self.prompt = prompt
        rows = CutoffPrompt._cutoff(prompt, clip, tokens, padding)
        # Parallel arrays indexed by block number.
        self.base = np.array([original for original, _ in rows])
        self.cut = np.array([masked for _, masked in rows])
        self.sw = np.array([False] * len(rows))  # True -> use the masked text

    @property
    def block_count(self):
        """Number of comma-separated blocks in the prompt."""
        return self.base.shape[0]

    def switch(self, block_index: int, to: Union[bool, None] = None):
        """Set the switch of one block and return its new value.

        With ``to`` left as ``None`` the current value is toggled.
        """
        if to is None:
            to = not self.sw[block_index]
        self.sw[block_index] = to
        return to

    def text(self, sw=None):
        """Re-assemble the prompt, taking masked blocks where ``sw`` is true.

        ``sw`` defaults to the instance's own switch array.
        """
        if sw is None:
            sw = self.sw
        blocks = np.where(sw, self.cut, self.base)
        return ','.join(blocks)

    def active_blocks(self) -> np.ndarray:
        """Indices of the blocks whose masked text differs from the original."""
        indices, = (self.base != self.cut).nonzero()
        return indices

    def generate(self):
        """Yield ``(switches, text)`` for every on/off combination.

        Only blocks reported by ``active_blocks()`` are varied, so
        ``2 ** len(active_blocks())`` items are produced.  ``switches`` is the
        tuple of booleans applied to those blocks, in index order.
        """
        indices = self.active_blocks()
        for diff_sw in product([False, True], repeat=indices.shape[0]):
            sw = np.full_like(self.sw, False)
            sw[indices] = diff_sw
            yield diff_sw, self.text(sw)
def generate_prompts(
    clip: CLIP,
    prompt: str,
    targets: List[str],
    padding: Union[str,int,Token],
) -> CutoffPrompt:
    """Build the ``CutoffPrompt`` for ``prompt`` and log every variant.

    Args:
        clip: text encoder used for token lookups.
        prompt: the prompt text to process.
        targets: words to be replaced by padding.
        padding: the padding token, given as a ``Token``, a token text or a
            vocabulary id; the latter two are resolved through the encoder.

    Returns:
        The ``CutoffPrompt`` built from the arguments.

    Raises:
        ValueError: if a non-``Token`` padding resolves to the encoder's
            end-of-text id.
    """
    te = ClipWrapper(clip)

    already_resolved = isinstance(padding, Token)
    if not already_resolved:
        # Keep the raw user value for the error message before rebinding.
        o_pad = padding
        padding = te.token(o_pad)
        if padding.id == te.id_end:
            raise ValueError(f'`{o_pad}` is not a valid token.')

    # The prompt text gets the padding without its `</w>` suffix.
    padding_text = padding.token.replace('</w>', '')
    result = CutoffPrompt(prompt, clip, targets, padding_text)

    log(f'[Cutoff] replace: {", ".join(targets)}')
    log(f'[Cutoff] to: {padding.token} ({padding.id})')
    log(f'[Cutoff] original: {prompt}')
    # One log line per on/off combination of the affected blocks.
    for i, (_, pp) in enumerate(result.generate()):
        log(f'[Cutoff] #{i}: {pp}')

    return result
def token_to_block(clip: CLIP, prompt: str):
    """Tokenize ``prompt`` and tag every token with its comma-block index.

    The output is a flat list of ``(Token, block_index)`` pairs laid out in
    chunks: a start marker, the chunk's tokens padded with end markers up to
    ``CHUNK_LENGTH`` entries, then one more end marker.  Markers, padding and
    comma tokens carry block index ``-1``; every other token carries the
    number of commas seen before it.  Tokens covered by a textual-inversion
    embedding are represented by ``te.token(0)`` placeholders, one per
    embedding vector.
    """
    wrapper = ClipWrapper(clip)

    # cf. sd_hijack_clip.py
    parsed = prompt_parser.parse_prompt_attention(prompt)
    tokenized: List[List[int]] = clip.tokenize([text for text, _ in parsed])

    CHUNK_LENGTH = 75
    start_entry = (wrapper.token(wrapper.id_start), -1)  # type: ignore
    end_entry = (wrapper.token(wrapper.id_end), -1)  # type: ignore
    comma_id = wrapper.token(',</w>').id

    last_comma = -1
    block_index = 0
    pending: List[Tuple[Token,int]] = []
    result: List[Tuple[Token,int]] = []

    def flush_chunk():
        """Emit ``pending`` as one padded chunk and reset the chunk state."""
        nonlocal pending, last_comma
        missing = CHUNK_LENGTH - len(pending)
        result.append(start_entry)
        result.extend(pending)
        if 0 < missing:
            result.extend([end_entry] * missing)
        result.append(end_entry)
        pending = []
        last_comma = -1

    for ids, (text, weight) in zip(tokenized, parsed):
        # A `BREAK` entry with weight -1 closes the current chunk.
        if text == 'BREAK' and weight == -1:
            flush_chunk()
            continue

        pos = 0
        while pos < len(ids):
            token_id = ids[pos]
            is_comma = token_id == comma_id

            if is_comma:
                last_comma = len(pending)
                block_index += 1
            elif (
                shared.opts.comma_padding_backtrack != 0
                and len(pending) == CHUNK_LENGTH
                and last_comma != -1
                and len(pending) - last_comma <= shared.opts.comma_padding_backtrack
            ):
                # The chunk is full and a comma is close enough: move what
                # follows that comma into the next chunk.
                split_at = last_comma + 1
                carried = pending[split_at:]
                pending = pending[:split_at]
                flush_chunk()
                pending = carried

            if len(pending) == CHUNK_LENGTH:
                flush_chunk()

            embedding, consumed = wrapper.hijack.embedding_db.find_embedding_at_position(ids, pos)
            if embedding is not None:
                emb_len = int(embedding.vec.shape[0])
                # Start a new chunk when the embedding does not fit in the
                # space left in this one.
                if len(pending) + emb_len > CHUNK_LENGTH:
                    flush_chunk()
                pending.extend([(wrapper.token(0), block_index)] * emb_len)
                pos += consumed
                continue

            # Ordinary token: commas belong to no block.
            owner = -1 if is_comma else block_index
            pending.append((wrapper.token(token_id), owner))
            pos += 1

    if len(pending) > 0:
        flush_chunk()

    return result