dikdimon's picture
Upload sd-webui-xl_vec using SD-Hub
8c0bf45 verified
import torch
import open_clip
try:
from sgm.modules import GeneralConditioner as CLIP_SDXL
from sgm.modules.encoders.modules import FrozenOpenCLIPEmbedder2
from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords
except:
print(f"[XL Vec] failed to load `sgm.modules`")
raise
def get_pooled(clip: CLIP_SDXL, text: str, layer='last', index=-1):
    """Compute a pooled text vector for *text* at a chosen token position.

    The second embedder of ``clip`` is run manually (token embedding,
    positional embedding, transformer, final layer norm), one token position
    is picked from the result, and that row is multiplied by the model's
    text projection.

    cf. sgm/modules/encoders/modules.py:FrozenOpenCLIPEmbedder2

    Args:
        clip: Conditioner whose ``embedders[1]`` is a
            ``FrozenOpenCLIPEmbedder2`` or a
            ``FrozenOpenCLIPEmbedder2WithCustomWords`` wrapping one.
        text: Prompt to tokenize with ``open_clip.tokenize``.
        layer: Key used to select one entry from the result of
            ``text_transformer_forward``.
        index: Token position to pool from.  A non-negative value is used
            as-is; a negative value is counted back from the argmax of the
            token ids (``-1`` selects that position itself).
            NOTE(review): the argmax is presumably the end-of-text token
            position - confirm against the tokenizer's vocabulary.

    Returns:
        A ``(pooled, real_index)`` tuple: the projected vector and the
        resolved non-negative token position it was taken from.

    Raises:
        AssertionError: If the embedder has an unexpected type or the
            resolved position falls outside ``[0, 77)``.
    """
    encoder = clip.embedders[1]
    if isinstance(encoder, FrozenOpenCLIPEmbedder2WithCustomWords):
        # Work on the underlying embedder rather than the custom-words wrapper.
        encoder = encoder.wrapped
    assert isinstance(encoder, FrozenOpenCLIPEmbedder2)

    token_ids = open_clip.tokenize([text]).to(encoder.device)
    model = encoder.model

    # [batch_size, n_ctx, d_model]
    hidden = model.token_embedding(token_ids) + model.positional_embedding
    layer_outputs = encoder.text_transformer_forward(
        hidden.permute(1, 0, 2),  # NLD -> LND
        attn_mask=model.attn_mask,
    )
    normed = model.ln_final(layer_outputs[layer])

    # The position lives in an integer tensor shaped like the argmax result,
    # so it can be used directly as an index tensor further down.
    eot_position = token_ids.argmax(dim=-1)
    position = torch.zeros_like(eot_position)
    position[0] = index if 0 <= index else eot_position.item() + index + 1

    real_index = position.item()
    assert 0 <= real_index < 77, f'index={index}, real_index={real_index}'

    selected = normed[torch.arange(normed.shape[0]), position]
    return selected @ model.text_projection, real_index