File size: 1,685 Bytes
55e58d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import open_clip
from open_clip import tokenizer
import torch
import numpy as np
from evaluation.constants import MATTERPORT_LABELS, SCANNET_LABELS, SCANNETPP_LABELS, SCANNET18_LABELS, SCANNETPP84_LABELS, SCANNETPP84_IDS, ARKIT_LABELS, ARKIT_IDS
def load_clip():
    """Create, CUDA-place, and return the open_clip ViT-H-14 text/image model.

    Returns:
        The CLIP model in eval mode on the GPU. The preprocessing transforms
        returned by ``create_model_and_transforms`` are discarded because this
        script only encodes text.
    """
    # fix: plain strings — the originals were f-strings with no placeholders
    print('[INFO] loading CLIP model...')
    model, _, _ = open_clip.create_model_and_transforms(
        "ViT-H-14", pretrained="laion2b_s32b_b79k"
    )
    model.cuda()
    model.eval()
    # fix: was print(f'[INFO]', ' finish...') which emitted a doubled space
    print('[INFO] finish loading CLIP model...')
    return model
def extract_text_feature(save_path, descriptions, clip_model=None):
    """Encode each description with CLIP and save a {description: feature} dict.

    Args:
        save_path: destination ``.npy`` path for the pickled feature dict
            (load with ``np.load(..., allow_pickle=True)``).
        descriptions: sequence of label strings to encode.
        clip_model: optional CLIP model to use; defaults to the module-level
            ``model`` (kept for backward compatibility with existing callers).
    """
    if clip_model is None:
        # Fall back to the global loaded at import time (defined below the
        # function in this file — valid because calls happen after load_clip()).
        clip_model = model
    text_tokens = tokenizer.tokenize(descriptions).cuda()
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        # L2-normalize so dot products between features are cosine similarities.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    text_features = text_features.cpu().numpy()
    text_features_dict = {
        description: text_features[i]
        for i, description in enumerate(descriptions)
    }
    np.save(save_path, text_features_dict)
def get_text_feature(text):
    """Return the raw (un-normalized) CLIP text embedding as a numpy array.

    Uses the module-level ``model``; no gradient tracking is needed for
    inference, so the forward pass runs under ``torch.no_grad()``.
    """
    tokens = tokenizer.tokenize(text).cuda()
    with torch.no_grad():
        features = model.encode_text(tokens).float()
    return features.cpu().numpy()
# Load the CLIP model once; the helper functions above read this global.
model = load_clip()

# Pre-compute and cache text features for every supported benchmark label set.
_FEATURE_TARGETS = [
    ('data/text_features/scannet.npy', SCANNET_LABELS),
    ('data/text_features/scannetpp.npy', SCANNETPP_LABELS),
    ('data/text_features/matterport3d.npy', MATTERPORT_LABELS),
    ('data/text_features/scannet18.npy', SCANNET18_LABELS),
    ('data/text_features/scannetpp84.npy', SCANNETPP84_LABELS),
    ('data/text_features/arkit.npy', ARKIT_LABELS),
]
for _save_path, _labels in _FEATURE_TARGETS:
    extract_text_feature(_save_path, _labels)