|
|
'''
    This script generates a semantic label for every object in the scene by
    comparing each object's open-vocabulary features with the label text features.
'''
|
|
from utils.config import get_dataset, get_args |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
def sigmoid(x):
    '''
        Return the logistic function 1 / (1 + exp(-x)) of x.

        x is passed through np.exp, so scalars and NumPy arrays are handled
        element-wise.
    '''
    denominator = 1 + np.exp(-x)
    return 1 / denominator
|
|
|
|
|
def cos_sim(x_cos):
    '''
        Linearly rescale x_cos with (1 + x_cos) / 2.

        An input of -1 maps to 0, 0 maps to 0.5 and 1 maps to 1. Works
        element-wise on NumPy arrays as well as on scalars.
    '''
    shifted = 1 + x_cos
    return shifted / 2
|
|
|
|
|
def main(args):
    '''
        Assign a label to every object in the scene and save the predictions.

        For each entry of the loaded object dictionary, the features of its
        representative masks are averaged, compared with every label text
        feature through a dot product, rescaled with cos_sim, and the
        best-scoring label is stored together with its score and the object's
        point mask.

        Args:
            args: run arguments; forwarded to get_dataset, and args.config
                selects the sub-directory the input .npy files are read from.

        Writes '<dataset.object_dict_dir>/prediction.npz' with the arrays
        'pred_masks' (bool, points x instances), 'pred_score' and
        'pred_classes' (int32).
    '''
    dataset = get_dataset(args)
    total_point_num = dataset.get_scene_points().shape[0]

    # Text side: one feature row per label description, both in dict order so
    # that a row index can be mapped back to its description.
    label_features_dict = dataset.get_label_features()
    label_text_features = np.stack(list(label_features_dict.values()))
    descriptions = list(label_features_dict.keys())

    # Both files store a pickled dict inside a 0-d object array, hence .item().
    object_dict = np.load(f'{dataset.object_dict_dir}/{args.config}/object_dict.npy', allow_pickle=True).item()
    clip_feature = np.load(f'{dataset.object_dict_dir}/{args.config}/open-vocabulary_features.npy', allow_pickle=True).item()
    label2id = dataset.get_label_id()[0]

    # One output slot per object, in the iteration order of object_dict.
    num_instance = len(object_dict)
    pred_masks = np.zeros((total_point_num, num_instance), dtype=bool)
    pred_score = np.ones(num_instance)
    pred_classes = np.zeros(num_instance, dtype=np.int32)

    for slot, (_, obj_info) in enumerate(object_dict.items()):
        representative_masks = obj_info['repre_mask_list']
        if len(representative_masks) == 0:
            # NOTE(review): a skipped object keeps its initial values
            # (all-False mask, score 1, class 0) - confirm this is intended.
            continue

        # Features are keyed by '<mask_info[0]>_<mask_info[1]>'.
        per_mask_features = np.stack([
            clip_feature[f'{mask_info[0]}_{mask_info[1]}']
            for mask_info in representative_masks
        ])
        # keepdims keeps a leading axis of size 1 so the dot product below
        # yields a (1, number of labels) matrix.
        object_feature = np.mean(per_mask_features, axis=0, keepdims=True)

        similarity = cos_sim(np.dot(object_feature, label_text_features.T))
        # Collapse the size-1 leading axis to get one score per label.
        label_scores = np.max(similarity, axis=0)
        best_label = np.argmax(label_scores)

        pred_score[slot] = label_scores[best_label]
        pred_classes[slot] = label2id[descriptions[best_label]]

        # Turn the object's point ids into a dense boolean mask column.
        point_mask = np.zeros(total_point_num, dtype=bool)
        point_mask[list(obj_info['point_ids'])] = True
        pred_masks[:, slot] = point_mask

    # NOTE(review): unlike the inputs, the output path does not include
    # args.config, so runs with different configs write to the same file -
    # confirm this is intended.
    np.savez(
        f'{dataset.object_dict_dir}/prediction.npz',
        pred_masks=pred_masks,
        pred_score=pred_score,
        pred_classes=pred_classes,
    )
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: obtain the run arguments from get_args() and hand
    # them straight to main().
    main(get_args())