# zoo3d / MaskClustering / semantics / c_open-voc_query.py
# bulatko — commit 55e58d1: "adding real MK"
'''
Assign semantic labels to the predicted 3D object instances of each scene by
matching CLIP embeddings of multi-scale image crops against label text embeddings.
'''
from utils.config import get_dataset, get_args
from dataset.scannet import ScanNetDataset
import os
import numpy as np
import open_clip
import cv2
from PIL import Image
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from utils.config import update_args
import argparse
# Number of nested crops taken around each bounding box: level 0 is the box
# itself, and every higher level is expanded further by box_multi_level.
LEVELS = 3
def load_clip(device):
    """Create the open_clip ViT-H-14 model (laion2b_s32b_b79k weights) on ``device``.

    Returns:
        (model, preprocess): the model switched to eval mode, and the third
        value returned by ``open_clip.create_model_and_transforms`` (the
        image transform used for inference).
    """
    print('[INFO] loading CLIP model...')
    clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
        "ViT-H-14",
        pretrained="laion2b_s32b_b79k",
    )
    clip_model.to(device)
    clip_model.eval()
    print('[INFO]', ' finish loading CLIP model...')
    return clip_model, clip_preprocess
def box_multi_level(bbox, shape, level, expansion_ratio):
    """Grow ``bbox`` outward by ``level * expansion_ratio`` of its size, clipped to the image.

    Args:
        bbox: (left, top, right, bottom) pixel coordinates.
        shape: image shape; ``shape[0]`` bounds the vertical extent and
            ``shape[1]`` the horizontal extent.
        level: expansion level; level 0 returns ``bbox`` as given (no clipping).
        expansion_ratio: fraction of the box width/height added on each side per level.

    Returns:
        (left, top, right, bottom) of the expanded, clipped box.
    """
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom
    # The fractional margin is truncated first, then multiplied by the level.
    pad_x = level * int(abs(right - left) * expansion_ratio)
    pad_y = level * int(abs(bottom - top) * expansion_ratio)
    height, width = shape[0], shape[1]
    return (
        max(0, left - pad_x),
        max(0, top - pad_y),
        min(width, right + pad_x),
        min(height, bottom + pad_y),
    )
def get_cropped_images(object, scene_id, preprocess, num_images=5, expansion_ratio=0.1, preloaded_images=None):
    """Crop an object from up to ``num_images`` of its frames at LEVELS context sizes.

    Args:
        object: instance record; ``object['frames']`` is a list of dicts with
            ``'frame_id'`` and ``'bbox'`` = (x1, y1, x2, y2) pixel coordinates.
        scene_id: unused; kept for call compatibility.
        preprocess: transform applied to each square PIL crop (the CLIP
            preprocess in this script).
        num_images: maximum number of frames used, taken from the front of
            ``object['frames']``.
        expansion_ratio: per-level box growth passed to ``box_multi_level``.
        preloaded_images: mapping frame_id -> image array in BGR channel order
            (as produced by ``cv2.imread``).

    Returns:
        (crops, preview): ``crops`` is ``torch.stack`` of the preprocessed crops
        in frame-major, level-minor order; ``preview`` is a horizontal strip of
        64x64 thumbnails of the same crops in BGR order (ready for ``cv2.imwrite``).

    Raises:
        ValueError: if ``preloaded_images`` is missing, a frame's image could
            not be loaded, or the object has no frames.
    """
    if preloaded_images is None:
        raise ValueError('preloaded_images is required')
    crops = []
    thumbnails = []
    for frame in object['frames'][:num_images]:
        image = preloaded_images[frame['frame_id']]
        if image is None:
            raise ValueError(f"image for frame {frame['frame_id']} could not be loaded")
        image_shape = np.asarray(image).shape
        base_bbox = tuple(frame['bbox'])
        for level in range(LEVELS):
            # Always expand the ORIGINAL box: feeding the previous level's box
            # back in would compound the growth, since box_multi_level already
            # scales the margin by `level`.
            x1, y1, x2, y2 = box_multi_level(base_bbox, image_shape, level, expansion_ratio)
            rgb_crop = cv2.cvtColor(image[y1:y2, x1:x2], cv2.COLOR_BGR2RGB)
            square = pad_into_square(Image.fromarray(rgb_crop))
            crops.append(preprocess(square))
            thumbnails.append(np.asarray(square.resize((64, 64))))
    if not crops:
        raise ValueError('object has no frames to crop')
    # Thumbnails are RGB; reverse the channel axis to BGR for OpenCV writers.
    return torch.stack(crops), np.concatenate(thumbnails, axis=1)[..., ::-1]
def pad_into_square(image):
    """Center a PIL ``image`` on a white RGB square whose side equals its longer edge."""
    w, h = image.size
    side = w if w >= h else h
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(image, offset)
    return canvas
def preload_images(scene, frame_ids, image_root='/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images'):
    """Read the frames of ``scene`` into memory with ``cv2.imread``.

    Args:
        scene: scene directory name under ``image_root``.
        frame_ids: frame ids to load; each maps to ``<image_root>/<scene>/<id zero-padded to 5>.jpg``.
        image_root: directory holding the per-scene posed images.

    Returns:
        dict frame_id -> image array (BGR, as returned by cv2). Frames that
        cv2 could not read map to ``None``; a warning is printed for them.
    """
    preloaded_images = {}
    missing = []
    for frame_id in frame_ids:
        img_path = os.path.join(image_root, scene, f'{str(frame_id).zfill(5)}.jpg')
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            missing.append(img_path)
        preloaded_images[frame_id] = image
    if missing:
        print(f'[WARN] {len(missing)} image(s) could not be read for scene {scene}, e.g. {missing[0]}')
    return preloaded_images
def custom_probs(feature, label_text_features):
    """Classify an object by summing per-crop similarities to each label.

    Args:
        feature: crop embeddings, shape (N, D); a single (D,) embedding is
            treated as N = 1.
        label_text_features: label text embeddings, shape (C, D).

    Returns:
        (prob, label_index): softmax probability of the best label over the
        summed similarities, and its row index in ``label_text_features``.
    """
    feature = np.atleast_2d(np.asarray(feature))
    label_text_features = np.asarray(label_text_features)
    # (N, C) similarities summed over crops -> (1, C) logits.
    logits = np.sum(feature @ label_text_features.T, axis=0, keepdims=True)
    # Shift by the max before exponentiating so large summed logits do not
    # overflow to inf (which would turn the probabilities into nan).
    logits = logits - np.max(logits, axis=1, keepdims=True)
    exp_logits = np.exp(logits)
    probs = (exp_logits / np.sum(exp_logits, axis=1, keepdims=True))[0]
    max_label_id = np.argmax(probs)
    return probs[max_label_id], max_label_id
def _encode_crops(model, crops, device, batch_size=32):
    """Embed ``crops`` with ``model.encode_image`` in batches of at most ``batch_size``.

    Returns a float32 numpy array holding one L2-normalised embedding per crop,
    in the same order as ``crops``.
    """
    batches = []
    with torch.no_grad():
        # torch.split bounds the batch size; torch.chunk would instead fix the
        # number of batches, letting individual batches grow past batch_size.
        for batch in torch.split(crops, batch_size):
            batch_features = model.encode_image(batch.to(device)).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            batches.append(batch_features.cpu().numpy())
    return np.concatenate(batches, axis=0)


def _classify_feature(feature, label_text_features, temperature=100.0):
    """Match the mean crop embedding against every label text embedding.

    Returns ``(prob, label_index)``: the highest temperature-scaled softmax
    probability and the row of ``label_text_features`` it belongs to.
    """
    object_feature = np.mean(feature, axis=0, keepdims=True)
    logits = (object_feature @ label_text_features.T) * temperature
    # Shift by the max before exponentiating: exp(100 * similarity) can
    # overflow float32 when similarities approach 1.
    logits = logits - np.max(logits, axis=1, keepdims=True)
    exp_logits = np.exp(logits)
    probs = (exp_logits / np.sum(exp_logits, axis=1, keepdims=True))[0]
    label_index = int(np.argmax(probs))
    return probs[label_index], label_index


def main(args):
    """Label every predicted instance of each scene in this worker's shard.

    For each scene, the instances stored in
    ``<args.path_to_predictions>/<scene>/infos.npy`` are cropped from their
    frames, embedded with ``args.model`` and matched against the dataset's
    label text embeddings. Masks, scores and class ids are written to
    ``data/prediction/<args.config>/<args.exp_name>/<scene>.npz``; scenes whose
    output already exists are skipped. With ``args.debug`` the crop previews
    and chosen labels are also written under ``logs/<args.exp_name>/<scene>``.
    """
    split_file = '/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/splits/scannet200_subset.txt'
    # atleast_1d: np.loadtxt returns a 0-d array for a single-line split file.
    scenes = np.atleast_1d(np.loadtxt(split_file, dtype=str))
    # This process handles every num_workers-th scene, starting at job_id.
    scenes = scenes[args.job_id::args.num_workers]
    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    os.makedirs(pred_dir, exist_ok=True)

    for scene in tqdm(scenes):
        args.seq_name = scene
        pred_path = f'{pred_dir}/{args.seq_name}.npz'
        if os.path.exists(pred_path):
            continue

        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]
        label_features_dict = dataset.get_label_features()
        descriptions = list(label_features_dict.keys())
        label_text_features = np.stack(list(label_features_dict.values()))
        label_ids = dataset.get_label_id()
        label2id, id2label = label_ids[0], label_ids[1]
        scene_name = dataset.seq_name
        logs_path = f'logs/{args.exp_name}/{scene}'
        # NOTE: allow_pickle can execute arbitrary code from the file; only
        # point path_to_predictions at trusted, locally produced results.
        object_dict = np.load(
            os.path.join(args.path_to_predictions, scene_name, 'infos.npy'),
            allow_pickle=True,
        ).item()
        if args.debug:
            print(label2id)
            print(id2label)
            print(scene_name)

        num_instance = len(object_dict)
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32),
        }

        # Only the first num_images frames of each object are cropped, so only
        # those frames need to be read from disk.
        frame_ids = {
            frame['frame_id']
            for obj in object_dict.values()
            for frame in obj['frames'][:args.num_images]
        }
        preloaded_images = preload_images(scene, list(frame_ids))

        for idx, (key, obj) in enumerate(object_dict.items()):
            crops, saved_images = get_cropped_images(
                obj, scene_name, args.preprocess, args.num_images,
                preloaded_images=preloaded_images,
            )
            feature = _encode_crops(args.model, crops, args.device)
            prob, max_label_id = _classify_feature(feature, label_text_features)
            label_id = label2id[descriptions[max_label_id]]
            pred_dict['pred_score'][idx] = prob
            pred_dict['pred_classes'][idx] = label_id
            pred_dict['pred_masks'][list(obj['mask']), idx] = True
            if args.debug:
                os.makedirs(logs_path, exist_ok=True)
                label_name = id2label[label_id]
                cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{label_name}({prob:.2f}).jpg', saved_images)
                print(key, label_id, label_name, "confidence:", prob)

        # NOTE: no instances are filtered out (e.g. by label id); every
        # instance is saved with its predicted class.
        np.savez(pred_path, **pred_dict)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)
    # --num_workers / --job_id shard the scene list: this process handles
    # scenes[job_id::num_workers] inside main().
    parser.add_argument('--num_workers', '-n', type=int, default=1)
    parser.add_argument(
        '--path_to_predictions',
        type=str,
        default='/home/jovyan/users/bulat/workspace/3drec/Indoor/Grounded-SAM-2/results/gsam_result/scannet200/yolo/memory_150_classes_198',
    )
    parser.add_argument('--job_id', '-i', type=int, default=0)
    # update_args (utils.config) post-processes the parsed namespace.
    args = update_args(parser.parse_args())
    # Load CLIP once and hand it to main() through the args namespace.
    args.model, args.preprocess = load_clip(args.device)
    main(args)