| |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
| import open_clip |
| import openseg_classes |
|
|
| import argparse |
| import json |
|
|
def article(name):
    """Return the indefinite article ('a' or 'an') for *name*.

    The choice is based on the first letter only, case-insensitively.
    The previous version tested only lowercase vowels, so capitalized
    names (e.g. 'Apple') were given 'a'; it also raised IndexError on
    an empty name, which now safely yields 'a'.
    """
    return 'an' if name[:1].lower() in ('a', 'e', 'i', 'o', 'u') else 'a'
|
|
def processed_name(name, rm_dot=False):
    """Normalize a raw category name for insertion into a prompt template.

    Underscores become spaces, '/' becomes ' or ', and the result is
    lowercased. With ``rm_dot=True`` any trailing periods are stripped
    (templates supply their own final '.').
    """
    cleaned = name.replace('_', ' ').replace('/', ' or ').lower()
    return cleaned.rstrip('.') if rm_dot else cleaned
|
|
|
|
# Prompt templates for CLIP text-embedding "prompt ensembling": each
# template is filled with a processed category name ('{}') and, where
# present, '{article}' is replaced by 'a'/'an' via article(). The
# per-template embeddings are averaged downstream.

# Minimal single-prompt variant (defined for completeness; the builders
# in this file all use multiple_templates).
single_template = [
    'a photo of {article} {}.'
]

# Hand-crafted ensemble of prompts covering scene context, photo
# quality/condition variations, and non-photographic renditions.
multiple_templates = [
    # Scene-context phrasings.
    'There is {article} {} in the scene.',
    'There is the {} in the scene.',
    'a photo of {article} {} in the scene.',
    'a photo of the {} in the scene.',
    'a photo of one {} in the scene.',

    # "itap" ("I took a picture") and plain photo phrasings.
    'itap of {article} {}.',
    'itap of my {}.',
    'itap of the {}.',
    'a photo of {article} {}.',
    'a photo of my {}.',
    'a photo of the {}.',
    'a photo of one {}.',
    'a photo of many {}.',

    # Subjective-quality adjectives.
    'a good photo of {article} {}.',
    'a good photo of the {}.',
    'a bad photo of {article} {}.',
    'a bad photo of the {}.',
    'a photo of a nice {}.',
    'a photo of the nice {}.',
    'a photo of a cool {}.',
    'a photo of the cool {}.',
    'a photo of a weird {}.',
    'a photo of the weird {}.',

    # Size adjectives.
    'a photo of a small {}.',
    'a photo of the small {}.',
    'a photo of a large {}.',
    'a photo of the large {}.',

    # Cleanliness adjectives.
    'a photo of a clean {}.',
    'a photo of the clean {}.',
    'a photo of a dirty {}.',
    'a photo of the dirty {}.',

    # Lighting conditions.
    'a bright photo of {article} {}.',
    'a bright photo of the {}.',
    'a dark photo of {article} {}.',
    'a dark photo of the {}.',

    # Visibility / image-degradation conditions.
    'a photo of a hard to see {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of {article} {}.',
    'a low resolution photo of the {}.',
    'a cropped photo of {article} {}.',
    'a cropped photo of the {}.',
    'a close-up photo of {article} {}.',
    'a close-up photo of the {}.',
    'a jpeg corrupted photo of {article} {}.',
    'a jpeg corrupted photo of the {}.',
    'a blurry photo of {article} {}.',
    'a blurry photo of the {}.',
    'a pixelated photo of {article} {}.',
    'a pixelated photo of the {}.',

    # Monochrome.
    'a black and white photo of the {}.',
    'a black and white photo of {article} {}.',

    # Material / non-photographic renditions.
    'a plastic {}.',
    'the plastic {}.',

    'a toy {}.',
    'the toy {}.',
    'a plushie {}.',
    'the plushie {}.',
    'a cartoon {}.',
    'the cartoon {}.',

    'an embroidered {}.',
    'the embroidered {}.',

    'a painting of the {}.',
    'a painting of a {}.',
]
|
|
def build_text_embedding_coco(categories, model):
    """Build prompt-ensemble text embeddings for COCO-style categories.

    For each category, every template in ``multiple_templates`` is filled
    in, tokenized, and encoded. Two embedding sets are produced per
    category: the L2-normalized mean of the final text embeddings, and
    the normalized mean of intermediate (layer-12) attention features
    from ``model.encode_text_endk``.

    Returns a tuple ``(zeroshot_weights, attn12_weights)``, each of shape
    (num_categories, embed_dim), stacked on the GPU.
    """
    with torch.no_grad():
        pooled_embeddings = []
        attn_embeddings = []
        for cat in categories:
            # Fill every template; prepend "This is " to prompts that
            # start with "a"/"the" (matches the original heuristic).
            prompts = []
            for tmpl in multiple_templates:
                prompt = tmpl.format(
                    processed_name(cat, rm_dot=True), article=article(cat)
                )
                if prompt.startswith(("a", "the")):
                    prompt = "This is " + prompt
                prompts.append(prompt)

            tokens = open_clip.tokenize(prompts).cuda()
            final_emb = model.encode_text(tokens)
            # NOTE(review): encode_text_endk is a project-specific model
            # method; stepk=12 presumably selects layer 12 — confirm.
            attn_feat, _, _ = model.encode_text_endk(tokens, stepk=12, normalize=True)

            # Normalize per-prompt, average over prompts, renormalize.
            final_emb = final_emb / final_emb.norm(dim=-1, keepdim=True)
            mean_emb = final_emb.mean(dim=0)
            mean_emb = mean_emb / mean_emb.norm()

            attn_mean = F.normalize(attn_feat.mean(0), dim=0)

            attn_embeddings.append(attn_mean)
            pooled_embeddings.append(mean_emb)

        zeroshot_weights = torch.stack(pooled_embeddings, dim=0)
        attn12_weights = torch.stack(attn_embeddings, dim=0)

    return zeroshot_weights, attn12_weights
|
|
|
|
def build_text_embedding_lvis_eng(categories, model, tokenizer):
    """Build one L2-normalized text embedding per LVIS-style category.

    A category name may contain comma-separated synonyms ("a,b,c");
    each synonym is embedded via the full prompt ensemble, the
    per-synonym mean embeddings are averaged, and the result is
    renormalized.

    Returns a tensor of shape (num_categories, embed_dim) on the GPU.
    """
    with torch.no_grad():
        category_embeddings = []

        for cat in tqdm(categories):
            synonym_embeddings = []
            for synonym in cat.split(","):
                # Fill every template; prepend "This is " to prompts
                # starting with "a"/"the" (original heuristic).
                prompts = []
                for tmpl in multiple_templates:
                    prompt = tmpl.format(
                        processed_name(synonym, rm_dot=True), article=article(synonym)
                    )
                    if prompt.startswith(("a", "the")):
                        prompt = "This is " + prompt
                    prompts.append(prompt)

                tokens = tokenizer(prompts).cuda()
                emb = model.encode_text(tokens)
                # Normalize per-prompt, then average over prompts.
                emb = emb / emb.norm(dim=-1, keepdim=True)
                synonym_embeddings.append(emb.mean(dim=0))

            # Average over synonyms and renormalize.
            cat_emb = torch.stack(synonym_embeddings, dim=0).mean(dim=0)
            cat_emb = cat_emb / cat_emb.norm()
            category_embeddings.append(cat_emb)

        all_text_embeddings = torch.stack(category_embeddings, dim=0)

    return all_text_embeddings
|
|
def build_text_embedding_lvis(categories, model, tokenizer):
    """Build one L2-normalized text embedding per category name.

    Each category is expanded through every template in
    ``multiple_templates``, tokenized, and encoded; per-prompt
    embeddings are normalized, averaged, and the mean is renormalized.

    Returns a tensor of shape (num_categories, embed_dim) on the GPU.

    Bug fix: the original recomputed the *unnormalized* mean after the
    normalize step, overwriting the normalized embedding and leaving the
    normalization as dead code. The normalized mean is now kept, which
    matches build_text_embedding_coco / build_text_embedding_lvis_eng.
    """
    templates = multiple_templates

    with torch.no_grad():
        all_text_embeddings = []
        for category in tqdm(categories):
            texts = [
                template.format(
                    processed_name(category, rm_dot=True), article=article(category)
                )
                for template in templates
            ]
            # Prepend "This is " to prompts starting with "a"/"the".
            texts = [
                "This is " + text if text.startswith("a") or text.startswith("the") else text
                for text in texts
            ]
            texts = tokenizer(texts).cuda()

            text_embeddings = model.encode_text(texts)
            # Normalize per-prompt, average over prompts, renormalize.
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()

            all_text_embeddings.append(text_embedding)
        all_text_embeddings = torch.stack(all_text_embeddings, dim=0)

    return all_text_embeddings
|
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: export prompt-ensemble text embeddings for the
    # COCO-STUFF + ADE20k category list to a .npy file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_version', default='EVA02-CLIP-B-16')
    parser.add_argument('--out_path', default='metadata/COCO_STUFF_ADE20k_Thing204_STUFF112_clip_hand_craft_EVACLIP_ViTB16.npy')
    parser.add_argument('--pretrained', default='eva')
    parser.add_argument('--cache_dir', default='checkpoints/EVA02_CLIP_B_psz16_s8B.pt')
    args = parser.parse_args()

    # Load the CLIP model and its matching tokenizer, then move to GPU.
    clip_model = open_clip.create_model(
        args.model_version, pretrained=args.pretrained, cache_dir=args.cache_dir
    )
    clip_tokenizer = open_clip.get_tokenizer(args.model_version)
    clip_model.cuda()

    # Category names come from the project's openseg class list.
    category_names = [
        entry['name']
        for entry in openseg_classes.COCO_STUFF_ADE20k_Thing204_STUFF112
    ]

    embeddings = build_text_embedding_lvis(category_names, clip_model, clip_tokenizer)
    np.save(args.out_path, embeddings.cpu().numpy())
|
|