File size: 1,685 Bytes
55e58d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import open_clip
from open_clip import tokenizer
import torch
import numpy as np
from evaluation.constants import MATTERPORT_LABELS, SCANNET_LABELS, SCANNETPP_LABELS, SCANNET18_LABELS, SCANNETPP84_LABELS, SCANNETPP84_IDS, ARKIT_LABELS, ARKIT_IDS

def load_clip():
    """Create the OpenCLIP text/image model and prepare it for inference.

    Builds the "ViT-H-14" architecture with the "laion2b_s32b_b79k"
    pretrained tag, moves it to the CUDA device and switches it to eval mode.
    Progress messages are printed before and after loading.

    Returns:
        The model object, i.e. the first item returned by
        ``open_clip.create_model_and_transforms``; the other two returned
        items are discarded.
    """
    print('[INFO] loading CLIP model...')
    # Only the first element (the model) of the returned triple is needed.
    clip_model = open_clip.create_model_and_transforms(
        "ViT-H-14",
        pretrained="laion2b_s32b_b79k",
    )[0]
    clip_model.cuda()
    clip_model.eval()
    print('[INFO]', ' finish loading CLIP model...')
    return clip_model

def extract_text_feature(save_path, descriptions):
    """Encode text descriptions with CLIP and save them as a dict to ``save_path``.

    Each description is tokenized, encoded with the module-level ``model``
    (which must already be loaded), L2-normalized along the last dimension
    and converted to a NumPy array. The result is stored with ``np.save`` as
    a ``{description: feature_vector}`` dictionary.

    Args:
        save_path: Destination passed to ``np.save``. When it is a filesystem
            path, missing parent directories are created first.
        descriptions: A single string or an iterable of strings. A single
            string is treated as one description rather than being iterated
            character by character.

    Note:
        Because a dict is saved, NumPy pickles it inside an object array;
        read it back with ``np.load(path, allow_pickle=True).item()``.
    """
    import os  # local import keeps this block self-contained

    # Normalize the input to a concrete list so it can be both tokenized and
    # iterated again below (a bare string or a one-shot iterable would
    # otherwise be mis-paired with the feature rows).
    if isinstance(descriptions, str):
        descriptions = [descriptions]
    else:
        descriptions = list(descriptions)

    text_tokens = tokenizer.tokenize(descriptions).cuda()
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        # L2-normalize each feature vector along the last dimension.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.cpu().numpy()

    # Row i of the feature matrix belongs to descriptions[i].
    text_features_dict = dict(zip(descriptions, text_features))

    # np.save does not create directories; do it here so a fresh checkout
    # without the output folder does not fail after the encoding work.
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    np.save(save_path, text_features_dict)

def get_text_feature(text):
    """Encode ``text`` with the module-level CLIP ``model``.

    The input is tokenized, moved to CUDA and passed through
    ``model.encode_text`` without gradient tracking; the result is cast to
    float and returned as a NumPy array on the CPU.

    Note:
        Unlike ``extract_text_feature``, the features returned here are NOT
        L2-normalized.
    """
    tokens = tokenizer.tokenize(text)
    tokens = tokens.cuda()
    with torch.no_grad():
        encoded = model.encode_text(tokens)
        encoded = encoded.float()
    return encoded.cpu().numpy()

# The CLIP model is loaded once at module level; extract_text_feature and
# get_text_feature read it through this global name.
model = load_clip()

# (output file, label list) pairs, processed in this order.
_FEATURE_TARGETS = (
    ('data/text_features/scannet.npy', SCANNET_LABELS),
    ('data/text_features/scannetpp.npy', SCANNETPP_LABELS),
    ('data/text_features/matterport3d.npy', MATTERPORT_LABELS),
    ('data/text_features/scannet18.npy', SCANNET18_LABELS),
    ('data/text_features/scannetpp84.npy', SCANNETPP84_LABELS),
    ('data/text_features/arkit.npy', ARKIT_LABELS),
)

for _feature_path, _label_list in _FEATURE_TARGETS:
    extract_text_feature(_feature_path, _label_list)