File size: 3,551 Bytes
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
a625e96
 
4aabce3
 
 
 
 
 
 
 
 
a625e96
4aabce3
 
 
 
 
 
 
 
 
 
 
87b5061
4aabce3
 
 
 
87b5061
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87b5061
4aabce3
 
 
 
a625e96
4aabce3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import open_clip
import torch
import os
from config import *
from tqdm import tqdm

def save_embeddings(model, tokenizer, device):
    """Compute and save one text embedding per caption line.

    Reads the captions file at ``text_captions_dir``. Every non-blank line is
    expected to look like ``<image name>.jpg,<caption>``. Captions that belong
    to the same image on consecutive lines are numbered 0, 1, 2, ... and each
    embedding is written to ``./{embedding_dir}/<image name>_<index>.pt``.
    Files that already exist in ``embedding_dir`` are not recomputed.

    Args:
        model: Text encoder forwarded to ``get_text_embedding``.
        tokenizer: Tokenizer forwarded to ``get_text_embedding``.
        device: Device forwarded to ``get_text_embedding``.

    Raises:
        ValueError: If a non-blank line does not contain the ``.jpg,``
            separator.
    """
    with open(text_captions_dir, "r") as f:
        lines = f.readlines()

    # List the output directory once instead of once per caption; the set is
    # kept in sync below so files written during this run are also skipped.
    existing = set(os.listdir(embedding_dir))

    last_img_name = ""
    counter = 0
    for line_no, raw_line in enumerate(tqdm(lines, desc="Progress"), start=1):
        line = raw_line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline at end of file) carry no caption.
            continue

        # Split on the first separator only, so a caption that itself contains
        # ".jpg," does not break the parse.
        img_name, separator, caption = line.partition(".jpg,")
        if not separator:
            raise ValueError(
                f"Line {line_no} of {text_captions_dir} is not of the form "
                f"'<image>.jpg,<caption>': {line!r}"
            )

        # Drop trailing periods; safe on an empty caption as well.
        caption = caption.strip().rstrip(".")

        # Restart the per-image caption index whenever the image changes.
        if not last_img_name or last_img_name != img_name:
            last_img_name = img_name
            counter = 0
        else:
            counter += 1

        file_name = f"{img_name}_{counter}.pt"
        if file_name in existing:
            continue

        embedding = get_text_embedding(model, tokenizer, caption, device)
        torch.save(embedding.cpu(), f"./{embedding_dir}/{file_name}")
        existing.add(file_name)

def save_null_embedding(model, tokenizer, device):
    """Embed the empty prompt and store it at ``null_embedding_dir``.

    The embedding is obtained from ``get_text_embedding`` with ``""`` as the
    caption, moved to CPU, and written with ``torch.save``. A confirmation
    message is printed afterwards.
    """
    empty_prompt = ""
    null_vector = get_text_embedding(model, tokenizer, empty_prompt, device).cpu()
    torch.save(null_vector, null_embedding_dir)
    print("Saved null embedding")

def save_val_embeddings(model, tokenizer, device):
    """Save embeddings for a fixed set of validation prompts.

    Creates ``unet_val_embeddings_dir`` if needed, lower-cases each prompt,
    embeds it with ``get_text_embedding`` and writes the CPU tensor to
    ``{unet_val_embeddings_dir}/val_embedding_<i>.pt`` where ``<i>`` is the
    prompt's position in the list.

    Args:
        model: Text encoder forwarded to ``get_text_embedding``.
        tokenizer: Tokenizer forwarded to ``get_text_embedding``.
        device: Device forwarded to ``get_text_embedding``.
    """
    os.makedirs(unet_val_embeddings_dir, exist_ok=True)
    val_prompts = [
        "A brown dog running through green grass",          # Anchor test
        "A person in a red shirt sitting on a bench",       # Spatial Logic test
        "A child playing with a colorful ball on a beach",  # Complex Scene test
        "A giant transparent bubble over a city street",    # Stress test
    ]
    # Wrap the list itself (not the enumerate object) so tqdm can read its
    # length and display a total / percentage.
    for i, caption in enumerate(tqdm(val_prompts, desc="Progress")):
        embedding = get_text_embedding(model, tokenizer, caption.lower(), device)
        torch.save(embedding.cpu(), f"{unet_val_embeddings_dir}/val_embedding_{i}.pt")

def get_embedding_model(device="cpu"):
    """Create the open_clip model used for text embeddings, in eval mode.

    The model is built from the ``embedding_model`` / ``embedding_pretrained``
    settings. The two other values returned by
    ``open_clip.create_model_and_transforms`` are not needed here and are
    discarded.

    Args:
        device: Device to place the model on. Defaults to ``"cpu"``, which is
            what this function always used before (the earlier CUDA detection
            was immediately overwritten). Pass ``None`` to pick CUDA when it
            is available and fall back to CPU otherwise.

    Returns:
        A ``(model, device)`` tuple, where ``device`` is the device the model
        was created on.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _, _ = open_clip.create_model_and_transforms(
        embedding_model, pretrained=embedding_pretrained, device=device
    )
    model.eval()
    return model, device

def get_text_embedding(model, tokenizer, caption, device):
    """Return the per-token contextual embeddings of ``caption``.

    The caption is tokenized, embedded, offset by the model's positional
    embedding, passed through ``model.transformer`` in sequence-first layout
    with a causal attention mask, and normalised by ``model.ln_final``.
    Everything runs under ``torch.no_grad()``.

    Returns:
        The final-layer token features with the leading batch dimension
        squeezed out (the original annotations give (77, 1024) for a single
        caption — depends on the model and tokenizer in use).
    """
    token_ids = tokenizer(caption).to(device)
    with torch.no_grad():
        hidden = model.token_embedding(token_ids) + model.positional_embedding
        # The transformer is fed (L, N, D), so swap batch and sequence axes.
        hidden = hidden.permute(1, 0, 2)
        n_tokens = hidden.shape[0]
        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal_mask = torch.full(
            (n_tokens, n_tokens), float("-inf"), device=hidden.device
        ).triu_(1)
        hidden = model.transformer(hidden, attn_mask=causal_mask)
        # Back to (N, L, D) before the final layer norm.
        normed = model.ln_final(hidden.permute(1, 0, 2))
        return normed.squeeze(0)

if __name__ == "__main__":
    # Build the text encoder and configure it for inference-only use.
    clip_model, run_device = get_embedding_model()
    # get_text_embedding feeds the transformer sequence-first tensors.
    clip_model.transformer.batch_first = False
    # Freeze all weights; nothing is trained in this script.
    clip_model.requires_grad_(False)

    clip_tokenizer = open_clip.get_tokenizer(embedding_model)
    os.makedirs(f"./{embedding_dir}", exist_ok=True)

    # Run the three export steps in order: captions, null prompt, validation prompts.
    for export_step in (save_embeddings, save_null_embedding, save_val_embeddings):
        export_step(clip_model, clip_tokenizer, run_device)

# Check if embeddings are equal
# emb1 = torch.load("./backend/core/null_embedding.pt", map_location="cpu", weights_only=True)
# emb2 = torch.load("./backend/core/embeddings_77_1024/null_embedding.pt", map_location="cpu", weights_only=True)
# print("Max diff:", (emb1 - emb2).abs().max().item())
# print("Are equal:", torch.allclose(emb2, emb1, atol=1e-5))