import argparse
import json

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

import open_clip
import openseg_classes


def article(name):
    # Choose the indefinite article from the first letter; a simple heuristic
    # that is good enough for category names.
    return 'an' if name[0] in 'aeiou' else 'a'


def processed_name(name, rm_dot=False):
    # Normalize a category name for prompting: underscores become spaces,
    # '/'-separated synonyms are joined with 'or', and the result is lower-cased.
    res = name.replace('_', ' ').replace('/', ' or ').lower()
    if rm_dot:
        res = res.rstrip('.')
    return res
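
# Illustrative outputs of the two helpers above (the category names here are
# made-up examples, not classes from the dataset):
#   article('orange')                                    -> 'an'
#   article('cat')                                       -> 'a'
#   processed_name('traffic_light/signal.', rm_dot=True) -> 'traffic light or signal'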


single_template = [
    'a photo of {article} {}.'
]

multiple_templates = [
    'There is {article} {} in the scene.',
    'There is the {} in the scene.',
    'a photo of {article} {} in the scene.',
    'a photo of the {} in the scene.',
    'a photo of one {} in the scene.',

    'itap of {article} {}.',
    'itap of my {}.',
    'itap of the {}.',
    'a photo of {article} {}.',
    'a photo of my {}.',
    'a photo of the {}.',
    'a photo of one {}.',
    'a photo of many {}.',

    'a good photo of {article} {}.',
    'a good photo of the {}.',
    'a bad photo of {article} {}.',
    'a bad photo of the {}.',
    'a photo of a nice {}.',
    'a photo of the nice {}.',
    'a photo of a cool {}.',
    'a photo of the cool {}.',
    'a photo of a weird {}.',
    'a photo of the weird {}.',

    'a photo of a small {}.',
    'a photo of the small {}.',
    'a photo of a large {}.',
    'a photo of the large {}.',

    'a photo of a clean {}.',
    'a photo of the clean {}.',
    'a photo of a dirty {}.',
    'a photo of the dirty {}.',

    'a bright photo of {article} {}.',
    'a bright photo of the {}.',
    'a dark photo of {article} {}.',
    'a dark photo of the {}.',

    'a photo of a hard to see {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of {article} {}.',
    'a low resolution photo of the {}.',
    'a cropped photo of {article} {}.',
    'a cropped photo of the {}.',
    'a close-up photo of {article} {}.',
    'a close-up photo of the {}.',
    'a jpeg corrupted photo of {article} {}.',
    'a jpeg corrupted photo of the {}.',
    'a blurry photo of {article} {}.',
    'a blurry photo of the {}.',
    'a pixelated photo of {article} {}.',
    'a pixelated photo of the {}.',

    'a black and white photo of the {}.',
    'a black and white photo of {article} {}.',

    'a plastic {}.',
    'the plastic {}.',

    'a toy {}.',
    'the toy {}.',
    'a plushie {}.',
    'the plushie {}.',
    'a cartoon {}.',
    'the cartoon {}.',

    'an embroidered {}.',
    'the embroidered {}.',

    'a painting of the {}.',
    'a painting of a {}.',
]
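
# How a template is instantiated (illustrative; 'cat' is a placeholder class):
#   multiple_templates[0].format(processed_name('cat'), article=article('cat'))
#   -> 'There is a cat in the scene.'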


def build_text_embedding_coco(categories, model):
    # Embed each category with every prompt template, then average the
    # per-template embeddings into one L2-normalized class embedding.
    templates = multiple_templates
    with torch.no_grad():
        zeroshot_weights = []
        attn12_weights = []
        for category in categories:
            texts = [
                template.format(processed_name(category, rm_dot=True), article=article(category))
                for template in templates
            ]
            # Turn prompts into full sentences, ViLD-style.
            texts = [
                "This is " + text if text.startswith("a") or text.startswith("the") else text
                for text in texts
            ]

            texts = open_clip.tokenize(texts).cuda()
            text_embeddings = model.encode_text(texts)
            # encode_text_endk is a repo-specific extension of the text encoder
            # that also exposes intermediate (layer-12) token features.
            text_attnfeatures, _, _ = model.encode_text_endk(texts, stepk=12, normalize=True)

            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()

            text_attnfeatures = text_attnfeatures.mean(0)
            text_attnfeatures = F.normalize(text_attnfeatures, dim=0)
            attn12_weights.append(text_attnfeatures)
            zeroshot_weights.append(text_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=0)
        attn12_weights = torch.stack(attn12_weights, dim=0)

    return zeroshot_weights, attn12_weights
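
# Usage sketch (illustrative only; the pipeline at the bottom of this file uses
# the LVIS variant instead):
#   zeroshot, attn12 = build_text_embedding_coco(cat_names, model)
#   # zeroshot: (num_classes, embed_dim), one normalized embedding per class
#   # attn12:   stacked layer-12 text features with the same leading dimension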


def build_text_embedding_lvis_eng(categories, model, tokenizer):
    # Variant for comma-separated categories (lists of synonyms): each synonym
    # is embedded separately, and the synonym embeddings are averaged per class.
    templates = multiple_templates

    with torch.no_grad():
        all_text_embeddings = []

        for category in tqdm(categories):
            words = category.split(",")
            word_embeddings = []
            for word in words:
                texts = [
                    template.format(
                        processed_name(word, rm_dot=True), article=article(word)
                    )
                    for template in templates
                ]
                texts = [
                    "This is " + text if text.startswith("a") or text.startswith("the") else text
                    for text in texts
                ]

                texts = tokenizer(texts).cuda()

                text_embeddings = model.encode_text(texts)
                text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
                word_embedding = text_embeddings.mean(dim=0)
                word_embeddings.append(word_embedding)

            word_embeddings = torch.stack(word_embeddings, dim=0)
            category_embedding = word_embeddings.mean(dim=0)
            category_embedding /= category_embedding.norm()

            all_text_embeddings.append(category_embedding)

        all_text_embeddings = torch.stack(all_text_embeddings, dim=0)

    return all_text_embeddings


def build_text_embedding_lvis(categories, model, tokenizer):
    # Embed each category with every prompt template and average the
    # per-template embeddings into one L2-normalized class embedding.
    templates = multiple_templates

    with torch.no_grad():
        all_text_embeddings = []
        for category in tqdm(categories):
            texts = [
                template.format(
                    processed_name(category, rm_dot=True), article=article(category)
                )
                for template in templates
            ]
            texts = [
                "This is " + text if text.startswith("a") or text.startswith("the") else text
                for text in texts
            ]
            texts = tokenizer(texts).cuda()

            text_embeddings = model.encode_text(texts)

            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            # Renormalize the averaged embedding to unit length.
            text_embedding /= text_embedding.norm()

            all_text_embeddings.append(text_embedding)
        all_text_embeddings = torch.stack(all_text_embeddings, dim=0)

    return all_text_embeddings
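
# Sketch of downstream use (an assumption about the consumer, not part of this
# script): with L2-normalized image features, per-class logits reduce to a
# matrix product against the saved class embeddings.
#   class_embeddings = np.load(out_path)            # (num_classes, embed_dim)
#   logits = image_features @ class_embeddings.T    # cosine similarity per class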


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_version', default='EVA02-CLIP-B-16')
    parser.add_argument('--out_path', default='metadata/COCO_STUFF_ADE20k_Thing204_STUFF112_clip_hand_craft_EVACLIP_ViTB16.npy')
    parser.add_argument('--pretrained', default='eva')
    parser.add_argument('--cache_dir', default='checkpoints/EVA02_CLIP_B_psz16_s8B.pt')

    args = parser.parse_args()

    model = open_clip.create_model(
        args.model_version, pretrained=args.pretrained, cache_dir=args.cache_dir
    )
    tokenizer = open_clip.get_tokenizer(args.model_version)
    model.cuda()
    model.eval()  # inference mode: disable dropout and similar train-time behavior

    cat_data = openseg_classes.COCO_STUFF_ADE20k_Thing204_STUFF112

    cat_names = [x['name'] for x in cat_data]

    out_path = args.out_path
    text_embeddings = build_text_embedding_lvis(cat_names, model, tokenizer)
    np.save(out_path, text_embeddings.cpu().numpy())
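
    # Example invocation (the script filename is a placeholder; the defaults
    # above already point at the EVA02-CLIP-B/16 checkpoint and output path):
    #   python build_text_embedding.py \
    #       --model_version EVA02-CLIP-B-16 \
    #       --pretrained eva \
    #       --cache_dir checkpoints/EVA02_CLIP_B_psz16_s8B.pt \
    #       --out_path metadata/COCO_STUFF_ADE20k_Thing204_STUFF112_clip_hand_craft_EVACLIP_ViTB16.npy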