Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import open_clip | |
| import torch | |
| import os | |
| from config import * | |
| from tqdm import tqdm | |
def save_embeddings(model, tokenizer, device):
    """Encode every caption in ``text_captions_dir`` and save one tensor per caption.

    Each usable line of the captions file has the form ``<image>.jpg,<caption>``.
    The n-th caption seen for an image (n starting at 0) is saved as
    ``<image>_<n>.pt`` inside ``embedding_dir``. Captions whose output file
    already exists are skipped, which lets a partial run be resumed.

    Args:
        model: text model forwarded to ``get_text_embedding``.
        tokenizer: tokenizer forwarded to ``get_text_embedding``.
        device: device the token ids are moved to before encoding.
    """
    with open(text_captions_dir, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Per-image caption counter. Keying by image name (instead of resetting
    # whenever the name changes) keeps numbering correct even when one image's
    # captions are not on consecutive lines; otherwise a later group would
    # restart at _0, be treated as "already saved", and silently be skipped.
    captions_seen = {}
    for line in tqdm(lines, desc="Progress"):
        img_name, sep, caption = line.strip().partition(".jpg,")
        if not sep:
            # Blank line or a header row such as "image,caption": nothing to encode.
            continue
        # rstrip is a no-op when there is no trailing '.', and unlike
        # caption[-1] it does not fail on an empty caption.
        caption = caption.strip().rstrip(".")

        index = captions_seen.get(img_name, -1) + 1
        captions_seen[img_name] = index

        # Same path for the existence check and the save.
        out_path = os.path.join(embedding_dir, f"{img_name}_{index}.pt")
        # One stat per caption instead of re-listing the whole directory each time.
        if os.path.exists(out_path):
            continue
        embedding = get_text_embedding(model, tokenizer, caption, device)
        torch.save(embedding.cpu(), out_path)
def save_null_embedding(model, tokenizer, device):
    """Encode the empty prompt and write it to ``null_embedding_dir``.

    Despite its name, ``null_embedding_dir`` is used directly as the output file
    path for ``torch.save``. The empty-string embedding is presumably the
    unconditional input for the downstream UNet (TODO: confirm with the consumer).
    """
    empty_prompt_features = get_text_embedding(model, tokenizer, "", device).cpu()
    torch.save(empty_prompt_features, null_embedding_dir)
    print("Saved null embedding")
def save_val_embeddings(model, tokenizer, device, prompts=None):
    """Encode validation prompts and save them as ``val_embedding_<i>.pt``.

    Files go to ``unet_val_embeddings_dir`` (created if missing). ``i`` is the
    prompt's position in ``prompts``.

    Args:
        model: text model forwarded to ``get_text_embedding``.
        tokenizer: tokenizer forwarded to ``get_text_embedding``.
        device: device the token ids are moved to before encoding.
        prompts: optional sequence of prompt strings. Defaults to the four
            built-in probes below.
    """
    if prompts is None:
        prompts = [
            "A brown dog running through green grass",  # Anchor test
            "A person in a red shirt sitting on a bench",  # Spatial Logic test
            "A child playing with a colorful ball on a beach",  # Complex Scene test
            "A giant transparent bubble over a city street",  # Stress test
        ]
    os.makedirs(unet_val_embeddings_dir, exist_ok=True)

    # Wrap the sequence itself (not the enumerate object) so tqdm can read its
    # length and show a real total/percentage.
    for i, caption in enumerate(tqdm(prompts, desc="Progress")):
        # NOTE(review): validation prompts are lowercased, but training captions
        # in save_embeddings are not. If the tokenizer does not lowercase by
        # itself, this is a train/val mismatch. Confirm.
        embedding = get_text_embedding(model, tokenizer, caption.lower(), device)
        torch.save(embedding.cpu(), f"{unet_val_embeddings_dir}/val_embedding_{i}.pt")
def get_embedding_model(device="cpu"):
    """Load the open_clip model named by ``embedding_model`` in eval mode.

    Args:
        device: where to load the model. Defaults to ``"cpu"``, matching the
            previously hard-coded behaviour. Pass e.g. ``"cuda"`` to load on a GPU.

    Returns:
        ``(model, device)``: the model and the device it was loaded on, so
        callers can move token tensors to the same place.
    """
    # The image preprocessing transforms are not needed for text encoding.
    model, _, _ = open_clip.create_model_and_transforms(
        embedding_model, pretrained=embedding_pretrained, device=device
    )
    model.eval()
    return model, device
def get_text_embedding(model, tokenizer, caption, device):
    """Return per-token contextual features for ``caption`` with the batch axis removed.

    The text tower is run step by step:
    1. token embedding plus positional embedding;
    2. the transformer with a causal mask;
    3. the final LayerNorm.

    No pooling or projection is applied, so every token position is kept.

    Args:
        model: object exposing ``token_embedding``, ``positional_embedding``,
            ``transformer`` and ``ln_final``.
        tokenizer: callable mapping a string to a ``(1, T)`` token-id tensor.
            The shape is assumed from the original ``(1, 77)`` note; confirm it
            for other tokenizers.
        caption: text to encode.
        device: device the token ids are moved to.

    Returns:
        Tensor of shape ``(T, D)``. The original notes ``(77, 1024)`` for the
        configured model; the exact size depends on ``embedding_model``.

    NOTE: the hidden states are passed to ``model.transformer`` sequence-first,
    as ``(T, N, D)``. This script's ``__main__`` sets
    ``model.transformer.batch_first = False`` before calling this function.
    Callers that load the model some other way should confirm their
    transformer expects this layout.
    """
    token_ids = tokenizer(caption).to(device)
    with torch.no_grad():
        hidden = model.token_embedding(token_ids) + model.positional_embedding
        hidden = hidden.transpose(0, 1)  # (N, T, D) -> (T, N, D)

        n_tokens = hidden.shape[0]
        # Additive causal mask: -inf strictly above the diagonal, so position t
        # attends only to positions <= t.
        causal_mask = torch.full(
            (n_tokens, n_tokens), float("-inf"), device=hidden.device
        ).triu_(1)
        hidden = model.transformer(hidden, attn_mask=causal_mask)

        hidden = model.ln_final(hidden.transpose(0, 1))  # back to (N, T, D)
    return hidden.squeeze(0)
if __name__ == "__main__":
    model, device = get_embedding_model()
    # get_text_embedding hands the transformer sequence-first (T, N, D) tensors,
    # so switch the flag before encoding anything.
    model.transformer.batch_first = False
    # Inference only: freeze every parameter.
    model.requires_grad_(False)
    tokenizer = open_clip.get_tokenizer(embedding_model)

    os.makedirs(f"./{embedding_dir}", exist_ok=True)
    for save_step in (save_embeddings, save_null_embedding, save_val_embeddings):
        save_step(model, tokenizer, device)

    # Debug check (disabled): verify that the null embedding under backend/core
    # matches the copy inside embeddings_77_1024.
    # emb1 = torch.load("./backend/core/null_embedding.pt", map_location="cpu", weights_only=True)
    # emb2 = torch.load("./backend/core/embeddings_77_1024/null_embedding.pt", map_location="cpu", weights_only=True)
    # print("Max diff:", (emb1 - emb2).abs().max().item())
    # print("Are equal:", torch.allclose(emb2, emb1, atol=1e-5))