DenseVLM / tools /generate_text_embeddings.py
lyhisme's picture
Upload 151 files
c02d17f verified
# Modified from [ViLD](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild)
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import open_clip
import openseg_classes
import argparse
import json
def article(name):
return 'an' if name[0] in 'aeiou' else 'a'
def processed_name(name, rm_dot=False):
# _ for lvis
# / for obj365
res = name.replace('_', ' ').replace('/', ' or ').lower()
if rm_dot:
res = res.rstrip('.')
return res
single_template = [
'a photo of {article} {}.'
]
multiple_templates = [
'There is {article} {} in the scene.',
'There is the {} in the scene.',
'a photo of {article} {} in the scene.',
'a photo of the {} in the scene.',
'a photo of one {} in the scene.',
'itap of {article} {}.',
'itap of my {}.', # itap: I took a picture of
'itap of the {}.',
'a photo of {article} {}.',
'a photo of my {}.',
'a photo of the {}.',
'a photo of one {}.',
'a photo of many {}.',
'a good photo of {article} {}.',
'a good photo of the {}.',
'a bad photo of {article} {}.',
'a bad photo of the {}.',
'a photo of a nice {}.',
'a photo of the nice {}.',
'a photo of a cool {}.',
'a photo of the cool {}.',
'a photo of a weird {}.',
'a photo of the weird {}.',
'a photo of a small {}.',
'a photo of the small {}.',
'a photo of a large {}.',
'a photo of the large {}.',
'a photo of a clean {}.',
'a photo of the clean {}.',
'a photo of a dirty {}.',
'a photo of the dirty {}.',
'a bright photo of {article} {}.',
'a bright photo of the {}.',
'a dark photo of {article} {}.',
'a dark photo of the {}.',
'a photo of a hard to see {}.',
'a photo of the hard to see {}.',
'a low resolution photo of {article} {}.',
'a low resolution photo of the {}.',
'a cropped photo of {article} {}.',
'a cropped photo of the {}.',
'a close-up photo of {article} {}.',
'a close-up photo of the {}.',
'a jpeg corrupted photo of {article} {}.',
'a jpeg corrupted photo of the {}.',
'a blurry photo of {article} {}.',
'a blurry photo of the {}.',
'a pixelated photo of {article} {}.',
'a pixelated photo of the {}.',
'a black and white photo of the {}.',
'a black and white photo of {article} {}.',
'a plastic {}.',
'the plastic {}.',
'a toy {}.',
'the toy {}.',
'a plushie {}.',
'the plushie {}.',
'a cartoon {}.',
'the cartoon {}.',
'an embroidered {}.',
'the embroidered {}.',
'a painting of the {}.',
'a painting of a {}.',
]
def build_text_embedding_coco(categories, model):
templates = multiple_templates
with torch.no_grad():
zeroshot_weights = []
attn12_weights = []
for category in categories:
texts = [
template.format(processed_name(category, rm_dot=True), article=article(category))
for template in templates
]
texts = [
"This is " + text if text.startswith("a") or text.startswith("the") else text
for text in texts
]
texts = open_clip.tokenize(texts).cuda()
text_embeddings = model.encode_text(texts)
text_attnfeatures, _, _ = model.encode_text_endk(texts, stepk=12, normalize=True)
text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
text_embedding = text_embeddings.mean(dim=0)
text_embedding /= text_embedding.norm()
text_attnfeatures = text_attnfeatures.mean(0)
text_attnfeatures = F.normalize(text_attnfeatures, dim=0)
attn12_weights.append(text_attnfeatures)
zeroshot_weights.append(text_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=0)
attn12_weights = torch.stack(attn12_weights, dim=0)
return zeroshot_weights, attn12_weights
def build_text_embedding_lvis_eng(categories, model, tokenizer):
templates = multiple_templates
with torch.no_grad():
all_text_embeddings = []
for category in tqdm(categories):
words = category.split(",")
word_embeddings = []
for word in words:
texts = [
template.format(
processed_name(word, rm_dot=True), article=article(word)
)
for template in templates
]
texts = [
"This is " + text if text.startswith("a") or text.startswith("the") else text
for text in texts
]
texts = tokenizer(texts).cuda()
text_embeddings = model.encode_text(texts)
text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
word_embedding = text_embeddings.mean(dim=0)
word_embeddings.append(word_embedding)
word_embeddings = torch.stack(word_embeddings, dim=0)
category_embedding = word_embeddings.mean(dim=0)
category_embedding /= category_embedding.norm()
all_text_embeddings.append(category_embedding)
all_text_embeddings = torch.stack(all_text_embeddings, dim=0)
return all_text_embeddings
def build_text_embedding_lvis(categories, model, tokenizer):
templates = multiple_templates
with torch.no_grad():
all_text_embeddings = []
for category in tqdm(categories):
texts = [
template.format(
processed_name(category, rm_dot=True), article=article(category)
)
for template in templates
]
texts = [
"This is " + text if text.startswith("a") or text.startswith("the") else text
for text in texts
]
texts = tokenizer(texts).cuda() # tokenize
text_embeddings = model.encode_text(texts)
text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
text_embedding = text_embeddings.mean(dim=0)
text_embedding /= text_embedding.norm()
text_embedding = text_embeddings.mean(dim=0)
all_text_embeddings.append(text_embedding)
all_text_embeddings = torch.stack(all_text_embeddings, dim=0)
return all_text_embeddings
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_version', default='EVA02-CLIP-B-16')
parser.add_argument('--out_path', default='metadata/COCO_STUFF_ADE20k_Thing204_STUFF112_clip_hand_craft_EVACLIP_ViTB16.npy')
parser.add_argument('--pretrained', default='eva')
parser.add_argument('--cache_dir', default='checkpoints/EVA02_CLIP_B_psz16_s8B.pt')
args = parser.parse_args()
model = open_clip.create_model(
args.model_version, pretrained=args.pretrained, cache_dir=args.cache_dir
)
tokenizer = open_clip.get_tokenizer(args.model_version)
model.cuda()
cat_data = openseg_classes.COCO_STUFF_ADE20k_Thing204_STUFF112
cat_names = [x['name'] for x in cat_data]
out_path = args.out_path
text_embeddings = build_text_embedding_lvis(cat_names, model, tokenizer)
np.save(out_path, text_embeddings.cpu().numpy())