|
|
import torch |
|
|
import open_clip |
|
|
|
|
|
# SDXL conditioner / OpenCLIP embedder classes. If they cannot be imported,
# print a tagged hint for the user and re-raise so the original traceback is
# preserved. `except Exception` (not a bare `except:`) so that Ctrl-C or
# SystemExit during import is not misreported as a load failure.
try:
    from sgm.modules import GeneralConditioner as CLIP_SDXL
    from sgm.modules.encoders.modules import FrozenOpenCLIPEmbedder2
    from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords
except Exception:
    print("[XL Vec] failed to load `sgm.modules`")
    raise
|
|
|
|
|
|
|
|
|
|
|
def get_pooled(clip: CLIP_SDXL, text: str, layer='last', index=-1):
    """Return a projected text embedding taken at a single token position.

    Encodes ``text`` with the OpenCLIP embedder at ``clip.embedders[1]``.
    It then takes the final-layer-normed hidden state of ``layer`` at one
    token position and multiplies it by the model's ``text_projection``.

    Args:
        clip: SDXL conditioner whose ``embedders[1]`` is a
            ``FrozenOpenCLIPEmbedder2``, optionally wrapped in
            ``FrozenOpenCLIPEmbedder2WithCustomWords``.
        text: Prompt to encode (a single string; one sequence is produced).
        layer: Key used to pick one entry from the result of
            ``text_transformer_forward`` (default ``'last'``).
            NOTE(review): presumably ``'last'`` or ``'penultimate'`` —
            confirm against the sgm embedder.
        index: Token position to pool.
            ``index >= 0`` is an absolute position.
            A negative value is relative to the position found by
            ``tokens.argmax``: ``-1`` selects that position itself, ``-2``
            the one before it, and so on.

    Returns:
        ``(pooled, real_index)``: ``pooled`` is the projected embedding at
        the chosen position. ``real_index`` is the resolved absolute
        position as a Python int.

    Raises:
        AssertionError: If the embedder has an unexpected type, or the
            resolved position is outside ``[0, context_length)``. Raised
            explicitly (not via ``assert``) so the checks survive
            ``python -O``. Under ``-O`` a negative position would otherwise
            silently wrap around to the end of the sequence.
    """
    mod = clip.embedders[1]
    if isinstance(mod, FrozenOpenCLIPEmbedder2WithCustomWords):
        # The custom-words wrapper keeps the real sgm embedder in `.wrapped`.
        mod = mod.wrapped
    if not isinstance(mod, FrozenOpenCLIPEmbedder2):
        raise AssertionError(
            f'unexpected embedder type: {type(mod).__name__}')

    tokens = open_clip.tokenize([text]).to(mod.device)

    # Resolve and validate the pooling position before the (expensive)
    # forward pass; it depends only on the token ids.
    # argmax over token ids -- presumably locates the end-of-text token,
    # which has the largest id in the CLIP vocabulary.
    eot = tokens.argmax(dim=-1)
    real_index = int(index if 0 <= index else eot.item() + index + 1)
    # Bound by the actual tokenized context length rather than a literal 77.
    context_length = tokens.shape[-1]
    if not 0 <= real_index < context_length:
        raise AssertionError(f'index={index}, real_index={real_index}')

    # Re-run the encoder's forward steps by hand so that an arbitrary
    # `layer` of the transformer output can be selected.
    x = mod.model.token_embedding(tokens)
    x = x + mod.model.positional_embedding
    # Swap the first two axes -- presumably (batch, seq, dim) ->
    # (seq, batch, dim) as the transformer expects; TODO confirm.
    x = x.permute(1, 0, 2)
    x = mod.text_transformer_forward(x, attn_mask=mod.model.attn_mask)
    o = mod.model.ln_final(x[layer])

    # Same dtype/device as `eot`, one entry per row of the batch (1 here).
    p = torch.full_like(eot, real_index)
    pooled = o[torch.arange(o.shape[0]), p] @ mod.model.text_projection
    return pooled, real_index
|
|
|