File size: 1,378 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import open_clip

# These names come from packages outside this file. If any of the imports
# fails, print a short hint identifying this extension and let the original
# error propagate unchanged.
try:
    from sgm.modules import GeneralConditioner as CLIP_SDXL
    from sgm.modules.encoders.modules import FrozenOpenCLIPEmbedder2
    from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords
except BaseException:
    # Catch-all on purpose: the handler only reports and then re-raises.
    print("[XL Vec] failed to load `sgm.modules`")
    raise



def get_pooled(clip: CLIP_SDXL, text: str, layer='last', index=-1):
    """Return the projected text feature taken at one token position.

    Follows the same steps as the encoder in
    sgm/modules/encoders/modules.py:FrozenOpenCLIPEmbedder2, but lets the
    caller choose which token position is pooled.

    Args:
        clip: Conditioner whose second embedder (``embedders[1]``) is used.
        text: Prompt string, tokenized with ``open_clip.tokenize``.
        layer: Key used to pick one entry from the result of
            ``text_transformer_forward``.
        index: Token position to pool. A non-negative value is used as-is.
            A negative value is counted back from the position of the
            largest token id (which the code treats as end-of-text), so
            ``-1`` selects that position itself.

    Returns:
        A ``(pooled, real_index)`` tuple: the selected hidden state
        multiplied by the model's text projection, and the absolute token
        position that was used.

    Raises:
        AssertionError: If the (unwrapped) embedder is not a
            ``FrozenOpenCLIPEmbedder2`` or the resolved position is not in
            ``[0, 77)``.
    """
    encoder = clip.embedders[1]
    # The custom-words variant keeps the underlying embedder in `.wrapped`.
    if isinstance(encoder, FrozenOpenCLIPEmbedder2WithCustomWords):
        encoder = encoder.wrapped
    assert isinstance(encoder, FrozenOpenCLIPEmbedder2)

    tokens = open_clip.tokenize([text]).to(encoder.device)
    model = encoder.model

    # [batch_size, n_ctx, d_model], then NLD -> LND for the transformer.
    hidden = model.token_embedding(tokens) + model.positional_embedding
    hidden = hidden.permute(1, 0, 2)
    layer_outputs = encoder.text_transformer_forward(
        hidden, attn_mask=model.attn_mask
    )
    normed = model.ln_final(layer_outputs[layer])

    # Resolve the requested position through an integer tensor so that it can
    # be used directly for indexing below.
    eot = tokens.argmax(dim=-1)
    position = torch.zeros_like(eot)
    position[0] = index if index >= 0 else eot.item() + index + 1

    real_index = position.item()
    # NOTE(review): 77 presumably matches the tokenizer's context length;
    # confirm against open_clip.tokenize.
    assert 0 <= real_index < 77, f'index={index}, real_index={real_index}'

    # Pick, for each batch row, the hidden state at the chosen position.
    batch_rows = torch.arange(normed.shape[0])
    pooled = normed[batch_rows, position] @ model.text_projection

    return pooled, real_index