"""Precompute per-token text embeddings with an open_clip model and save them to disk.

Run as a script, this writes:
  * one ``.pt`` file per caption line of the captions file (``save_embeddings``),
  * an embedding of the empty prompt (``save_null_embedding``),
  * embeddings of a small fixed set of validation prompts (``save_val_embeddings``).

All paths and model identifiers (``text_captions_dir``, ``embedding_dir``,
``null_embedding_dir``, ``unet_val_embeddings_dir``, ``embedding_model``,
``embedding_pretrained``) come from the star-import of ``config``.
"""
import os
import sys

# Make sibling modules (e.g. ``config``) importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(__file__))

import open_clip
import torch
from config import *
from tqdm import tqdm


def save_embeddings(model, tokenizer, device):
    """Embed every caption in the captions file and save each one as a ``.pt`` file.

    Each non-blank line of ``text_captions_dir`` must look like
    ``<image name>.jpg,<caption>``. Consecutive lines that share an image name
    are numbered from 0, and the embedding of the N-th caption is written to
    ``./{embedding_dir}/<image name>_<N>.pt``. Files that already exist are
    skipped, so an interrupted run can be resumed.

    Args:
        model: open_clip model, passed through to ``get_text_embedding``.
        tokenizer: open_clip tokenizer, passed through to ``get_text_embedding``.
        device: device the tokens are moved to before the forward pass.

    Raises:
        ValueError: if a non-blank line does not contain the ``.jpg,`` separator.
    """
    out_dir = f"./{embedding_dir}"
    # Snapshot the directory once; listing it for every caption is O(n^2) in
    # the number of files.
    existing = set(os.listdir(out_dir))

    with open(text_captions_dir, "r") as f:
        lines = f.readlines()

    last_img_name = ""
    counter = 0
    for line_no, raw_line in enumerate(tqdm(lines, desc="Progress"), start=1):
        line = raw_line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue

        # Split on the first separator only, so a caption that itself contains
        # ".jpg," does not break the unpacking.
        img_name, sep, caption = line.partition(".jpg,")
        if not sep:
            raise ValueError(
                f"Line {line_no} of {text_captions_dir!r} is not of the form "
                f"'<image>.jpg,<caption>': {line!r}"
            )

        # Drop trailing periods; rstrip is a no-op when there are none and is
        # safe on an empty caption.
        caption = caption.strip().rstrip(".")

        # Number the captions of one image by their position in its run of lines.
        if not last_img_name or last_img_name != img_name:
            last_img_name = img_name
            counter = 0
        else:
            counter += 1
        file_name = f"{img_name}_{counter}.pt"

        if file_name not in existing:
            embedding = get_text_embedding(model, tokenizer, caption, device)
            torch.save(embedding.cpu(), f"{out_dir}/{file_name}")
            # Keep the snapshot current so a repeated name is still skipped.
            existing.add(file_name)


def save_null_embedding(model, tokenizer, device):
    """Embed the empty prompt ``""`` and save it to ``null_embedding_dir``."""
    embedding = get_text_embedding(model, tokenizer, "", device)
    torch.save(embedding.cpu(), null_embedding_dir)
    print("Saved null embedding")


def save_val_embeddings(model, tokenizer, device):
    """Embed the fixed validation prompts and save them as ``val_embedding_<i>.pt``.

    Prompts are lower-cased before embedding and written into
    ``unet_val_embeddings_dir``, which is created if missing.
    """
    os.makedirs(f"{unet_val_embeddings_dir}", exist_ok=True)
    val_prompts = [
        "A brown dog running through green grass",  # Anchor test
        "A person in a red shirt sitting on a bench",  # Spatial Logic test
        "A child playing with a colorful ball on a beach",  # Complex Scene test
        "A giant transparent bubble over a city street",  # Stress test
    ]
    # Wrap the list (not the enumerate object) so tqdm knows the total.
    for i, caption in enumerate(tqdm(val_prompts, desc="Progress")):
        embedding = get_text_embedding(model, tokenizer, caption.lower(), device)
        torch.save(embedding.cpu(), f"{unet_val_embeddings_dir}/val_embedding_{i}.pt")


def get_embedding_model():
    """Create the open_clip model named in ``config`` and put it in eval mode.

    Returns:
        A ``(model, device)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the auto-detected device is overridden here, so the model
    # always runs on CPU. Kept as-is in case it is deliberate - confirm, and
    # drop this line to enable CUDA.
    device = "cpu"
    model, _, _ = open_clip.create_model_and_transforms(
        embedding_model, pretrained=embedding_pretrained, device=device
    )
    model.eval()
    return model, device


def get_text_embedding(model, tokenizer, caption, device):
    """Return the per-token output of the text transformer for one caption.

    The final layer norm is applied to every token position; no pooling or
    projection happens here.

    NOTE(review): the transformer is fed sequence-first input, and the script
    entry point sets ``model.transformer.batch_first = False`` before calling
    this function. Code that imports this function presumably needs the same
    setup - confirm against the installed open_clip version.

    Args:
        model: open_clip model providing ``token_embedding``,
            ``positional_embedding``, ``transformer`` and ``ln_final``.
        tokenizer: callable turning ``caption`` into a token tensor.
        caption: the text to embed.
        device: device the tokens are moved to.

    Returns:
        The layer-normed transformer output with the leading batch axis
        squeezed out, e.g. (77, 1024) per the original author's note -
        the exact shape depends on the configured model.
    """
    tokens = tokenizer(caption).to(device)  # e.g. (1, 77)

    with torch.no_grad():
        x = model.token_embedding(tokens)
        x = x + model.positional_embedding
        x = x.permute(1, 0, 2)  # (L, N, D) for transformer

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere, so a
        # token cannot attend to later positions.
        seq_len = x.shape[0]
        mask = torch.empty(seq_len, seq_len, device=x.device)
        mask.fill_(float("-inf"))
        mask.triu_(1)

        x = model.transformer(x, attn_mask=mask)
        x = x.permute(1, 0, 2)  # back to (N, L, D)
        per_token_contextual = model.ln_final(x)  # e.g. (1, 77, 1024)

    return per_token_contextual.squeeze(0)  # e.g. (77, 1024)


if __name__ == "__main__":
    model, device = get_embedding_model()
    # get_text_embedding permutes its input to sequence-first before calling
    # the transformer, hence batch_first is switched off here.
    model.transformer.batch_first = False
    for p in model.parameters():
        p.requires_grad = False
    tokenizer = open_clip.get_tokenizer(embedding_model)

    os.makedirs(f"./{embedding_dir}", exist_ok=True)
    save_embeddings(model, tokenizer, device)
    save_null_embedding(model, tokenizer, device)
    save_val_embeddings(model, tokenizer, device)

    # Check if embeddings are equal
    # emb1 = torch.load("./backend/core/null_embedding.pt", map_location="cpu", weights_only=True)
    # emb2 = torch.load("./backend/core/embeddings_77_1024/null_embedding.pt", map_location="cpu", weights_only=True)
    # print("Max diff:", (emb1 - emb2).abs().max().item())
    # print("Are equal:", torch.allclose(emb2, emb1, atol=1e-5))