flickr8k-backend / core /text_embeddings.py
Rohan3's picture
Updated: VAE, UNet, config, text embeddings, model and main
a625e96
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import open_clip
import torch
import os
from config import *
from tqdm import tqdm
def save_embeddings(model, tokenizer, device):
    """Embed every caption in the captions file and save one tensor per caption.

    Each usable line of ``text_captions_dir`` has the form
    ``<image name>.jpg,<caption>``. Consecutive lines that share an image name
    are numbered 0, 1, 2, ... in file order, and each embedding is written to
    ``./<embedding_dir>/<image name>_<index>.pt``. Files already present in
    ``embedding_dir`` are skipped, so an interrupted run can be resumed.

    Lines without the ``.jpg,`` separator (blank lines, a CSV header row) are
    skipped. Trailing periods are removed from the caption before embedding.

    Args:
        model: Text encoder, forwarded to ``get_text_embedding``.
        tokenizer: Tokenizer, forwarded to ``get_text_embedding``.
        device: Device, forwarded to ``get_text_embedding``.

    Raises:
        FileNotFoundError: If the captions file or ``embedding_dir`` is missing.
    """
    # List the output directory once. Re-listing it for every caption makes
    # the whole run quadratic in the number of captions.
    already_saved = set(os.listdir(embedding_dir))

    with open(text_captions_dir, "r") as f:
        lines = f.readlines()  # materialised so tqdm can show a total

    last_img_name = ""
    counter = 0
    for line in tqdm(lines, desc="Progress"):
        img_name, separator, caption = line.strip().partition(".jpg,")
        if not separator:
            # Not an "<image>.jpg,<caption>" record.
            continue

        caption = caption.strip()
        # endswith() is safe on an empty caption, unlike caption[-1].
        if caption.endswith("."):
            caption = caption.rstrip(".")

        # The per-image index restarts whenever the image name changes.
        if not last_img_name or last_img_name != img_name:
            last_img_name = img_name
            counter = 0
        else:
            counter += 1

        file_name = f"{img_name}_{counter}.pt"
        if file_name in already_saved:
            continue

        embedding = get_text_embedding(model, tokenizer, caption, device)
        torch.save(embedding.cpu(), f"./{embedding_dir}/{file_name}")
        already_saved.add(file_name)
def save_null_embedding(model, tokenizer, device):
    """Embed the empty caption and save the result to ``null_embedding_dir``.

    Args:
        model: Text encoder, forwarded to ``get_text_embedding``.
        tokenizer: Tokenizer, forwarded to ``get_text_embedding``.
        device: Device, forwarded to ``get_text_embedding``.
    """
    null_embedding = get_text_embedding(model, tokenizer, "", device).cpu()
    torch.save(null_embedding, null_embedding_dir)
    print("Saved null embedding")
def save_val_embeddings(model, tokenizer, device):
    """Embed a fixed set of validation prompts and save them to disk.

    Creates ``unet_val_embeddings_dir`` if needed and writes one file per
    prompt, named ``val_embedding_<i>.pt`` where ``<i>`` is the prompt's index
    in the list below. Prompts are lower-cased before embedding.

    Args:
        model: Text encoder, forwarded to ``get_text_embedding``.
        tokenizer: Tokenizer, forwarded to ``get_text_embedding``.
        device: Device, forwarded to ``get_text_embedding``.
    """
    os.makedirs(unet_val_embeddings_dir, exist_ok=True)
    val_prompts = [
        "A brown dog running through green grass",  # Anchor test
        "A person in a red shirt sitting on a bench",  # Spatial Logic test
        "A child playing with a colorful ball on a beach",  # Complex Scene test
        "A giant transparent bubble over a city street",  # Stress test
    ]
    # Wrap the list, not the enumerate object: enumerate() has no length, so
    # tqdm(enumerate(...)) cannot display a total or percentage.
    for i, caption in enumerate(tqdm(val_prompts, desc="Progress")):
        embedding = get_text_embedding(model, tokenizer, caption.lower(), device)
        torch.save(embedding.cpu(), f"{unet_val_embeddings_dir}/val_embedding_{i}.pt")
def get_embedding_model():
    """Create the OpenCLIP model used for text embeddings, in eval mode on CPU.

    The architecture and weights are selected by the ``embedding_model`` and
    ``embedding_pretrained`` settings.

    Returns:
        tuple: ``(model, device)``, where ``device`` is the string ``"cpu"``.
    """
    # Embeddings are always computed on the CPU.
    # NOTE(review): to use a GPU, replace this with
    # torch.device("cuda" if torch.cuda.is_available() else "cpu").
    device = "cpu"
    # The image transforms returned alongside the model are not needed here.
    model, _, _ = open_clip.create_model_and_transforms(
        embedding_model, pretrained=embedding_pretrained, device=device
    )
    model.eval()
    return model, device
def get_text_embedding(model, tokenizer, caption, device):
    """Return the per-token contextual embeddings of ``caption``.

    Runs the text tower step by step (token embedding, positional embedding,
    transformer with a causal mask, final layer norm) and returns the
    normalised hidden state of every token, without pooling.

    Args:
        model: Object exposing ``token_embedding``, ``positional_embedding``,
            ``transformer`` and ``ln_final``.
        tokenizer: Callable mapping ``caption`` to a tensor of token ids.
        caption: Text to embed.
        device: Device the token ids are moved to.

    Returns:
        torch.Tensor: The ``ln_final`` output with a leading dimension of
        size 1 removed; (77, 1024) according to the original comments.
    """
    token_ids = tokenizer(caption).to(device)  # (1, 77)

    with torch.no_grad():
        hidden = model.token_embedding(token_ids) + model.positional_embedding

        # The transformer is fed sequence-first input: (N, L, D) -> (L, N, D).
        hidden = hidden.transpose(0, 1)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere, so a
        # position cannot attend to later positions.
        length = hidden.shape[0]
        causal_mask = torch.full(
            (length, length), float("-inf"), device=hidden.device
        ).triu_(1)

        hidden = model.transformer(hidden, attn_mask=causal_mask)

        # Back to batch-first (N, L, D) before the final layer norm.
        hidden = hidden.transpose(0, 1)
        contextual = model.ln_final(hidden)  # (B, T, D) = (1, 77, 1024)

    return contextual.squeeze(0)  # (77, 1024)
if __name__ == "__main__":
    # Script entry point: build the frozen text encoder, then write the
    # caption, null and validation embeddings to disk.
    model, device = get_embedding_model()

    # get_text_embedding() passes sequence-first tensors to the transformer,
    # so the flag is set before any embedding is computed.
    model.transformer.batch_first = False

    # The model is used for inference only; freeze every parameter.
    model.requires_grad_(False)

    tokenizer = open_clip.get_tokenizer(embedding_model)

    # save_embeddings() lists this directory, so it has to exist beforehand.
    os.makedirs(f"./{embedding_dir}", exist_ok=True)

    save_embeddings(model, tokenizer, device)
    save_null_embedding(model, tokenizer, device)
    save_val_embeddings(model, tokenizer, device)
# Check if embeddings are equal
# emb1 = torch.load("./backend/core/null_embedding.pt", map_location="cpu", weights_only=True)
# emb2 = torch.load("./backend/core/embeddings_77_1024/null_embedding.pt", map_location="cpu", weights_only=True)
# print("Max diff:", (emb1 - emb2).abs().max().item())
# print("Are equal:", torch.allclose(emb2, emb1, atol=1e-5))