# zoo3d / MaskClustering / semantics / 3dtrack_open-voc_query.py
# bulatko — commit 55e58d1 ("adding real MK")
'''
This script is used to generate the semantic labels for the objects in the scene.
'''
from utils.config import get_dataset, get_args
from dataset.scannet import ScanNetDataset
import os
import numpy as np
import open_clip
import cv2
from PIL import Image
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from utils.config import update_args
import argparse
# Number of crop scales taken around each object box in get_cropped_images:
# level 0 is the tight box, and each higher level is enlarged by box_multi_level.
LEVELS = 3
def load_clip(device):
    """Create the open_clip ViT-H-14 model with LAION-2B weights.

    Args:
        device: torch device (or device string) that the model is moved to.

    Returns:
        ``(model, preprocess)``: the model switched to eval mode and the
        third value returned by ``open_clip.create_model_and_transforms``
        (its evaluation image transform, per the open_clip docs).
    """
    print('[INFO] loading CLIP model...')
    clip_model, _train_transform, clip_preprocess = open_clip.create_model_and_transforms(
        "ViT-H-14",
        pretrained="laion2b_s32b_b79k",
    )
    clip_model.to(device)
    # The model is only used for inference in this script.
    clip_model.eval()
    print('[INFO]', ' finish loading CLIP model...')
    return clip_model, clip_preprocess
def box_multi_level(bbox, shape, level, expansion_ratio):
    """Grow ``bbox`` on every side by ``expansion_ratio * level`` of its size.

    Args:
        bbox: ``(left, top, right, bottom)`` pixel coordinates.
        shape: image shape; ``shape[0]`` bounds the vertical axis and
            ``shape[1]`` the horizontal axis.
        level: expansion level; 0 means "return the box unchanged".
        expansion_ratio: fraction of the box width/height added per level.

    Returns:
        ``(left, top, right, bottom)``. For ``level != 0`` the result is
        clamped to ``[0, shape[1]] x [0, shape[0]]``; level 0 is NOT clamped.
    """
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom
    # Keep the multiplication order (size * ratio * level) so int() truncation
    # sees exactly the same float value.
    pad_x = int(abs(right - left) * expansion_ratio * level)
    pad_y = int(abs(bottom - top) * expansion_ratio * level)
    height_limit, width_limit = shape[0], shape[1]
    new_left = max(0, left - pad_x)
    new_top = max(0, top - pad_y)
    new_right = min(width_limit, right + pad_x)
    new_bottom = min(height_limit, bottom + pad_y)
    return new_left, new_top, new_right, new_bottom
def get_cropped_images(object, scene_id, preprocess, preloaded_images, num_images=5, expansion_ratio=0.1):
    """Crop one object from its first ``num_images`` frames at ``LEVELS`` context sizes.

    Args:
        object: per-object record; ``object['frames']`` is a sequence of dicts
            with ``'frame_id'`` and ``'bbox'`` (``x1, y1, x2, y2`` in pixels).
            (The name shadows the builtin but is kept for caller compatibility.)
        scene_id: unused; kept so existing call sites keep working.
        preprocess: transform applied to each padded PIL crop; its outputs are
            combined with ``torch.stack``, so it must return tensors.
        preloaded_images: mapping ``frame_id -> image``. Images are treated as
            BGR (they go through ``cv2.COLOR_BGR2RGB``). ``None`` entries, which
            is what ``cv2.imread`` returns for unreadable files, are skipped.
        num_images: only the first ``num_images`` frames are used.
        expansion_ratio: per-level box growth, forwarded to ``box_multi_level``.

    Returns:
        ``(crops, preview)``. ``crops`` is the stacked preprocessed tensor.
        ``preview`` is a horizontal strip of 64x64 thumbnails in BGR order,
        ready for ``cv2.imwrite``. Returns ``(None, None)`` when no non-empty
        crop could be produced.
    """
    croped_images = []
    images = []
    for frame in object['frames'][:num_images]:
        image = preloaded_images[frame['frame_id']]
        if image is None:
            # Unreadable frame: skip it instead of crashing on its shape below.
            continue
        shape = np.asarray(image).shape
        x1, y1, x2, y2 = frame['bbox']
        x1, y1, x2, y2 = np.round([x1, y1, x2, y2]).astype(int)
        for level in range(LEVELS):
            x1_, y1_, x2_, y2_ = box_multi_level((x1, y1, x2, y2), shape, level, expansion_ratio)
            # Level 0 is returned unclamped; a negative start would wrap around
            # in numpy slicing and select the wrong region.
            x1_, y1_ = max(0, x1_), max(0, y1_)
            crop = image[y1_:y2_, x1_:x2_]
            if crop.size == 0:
                # Zero-size, inverted or out-of-image box: cv2.cvtColor rejects
                # empty input, so skip it (this also covers x1_ == x2_ / y1_ == y2_).
                continue
            pil_image = pad_into_square(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)))
            croped_images.append(preprocess(pil_image))
            images.append(np.asarray(pil_image.resize((64, 64))))
    if not croped_images:
        return None, None
    # Thumbnails are RGB; reverse the channels to BGR for cv2.imwrite.
    return torch.stack(croped_images), np.concatenate(images, axis=1)[..., ::-1]
def pad_into_square(image):
    """Center a PIL image on a white square canvas.

    Args:
        image: PIL image of any size.

    Returns:
        A new RGB ``PIL.Image`` whose side is ``max(width, height)``. The input
        is pasted centred (offsets rounded down) and the rest is white.
    """
    w, h = image.size
    side = w if w >= h else h
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(image, offset)
    return canvas
def preload_images(scene, frame_ids, images_root='/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images'):
    """Read the posed frames of ``scene`` into memory with ``cv2.imread``.

    Args:
        scene: scene id, used as the sub-directory of ``images_root``.
        frame_ids: iterable of frame ids. Each file name is ``str(frame_id)``
            zero-padded to 5 characters, plus ``.jpg``.
        images_root: directory with one folder per scene. The default is the
            path that used to be hard-coded here, so existing callers are
            unaffected.

    Returns:
        dict ``frame_id -> image``, where each image is the ``cv2.imread``
        result. Files that cannot be read map to ``None``, and a warning is
        printed for each of them.
    """
    preloaded_images = {}
    for frame_id in frame_ids:
        img_path = os.path.join(images_root, scene, f'{str(frame_id).zfill(5)}.jpg')
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread does not raise on missing or corrupt files; make that visible.
            print(f'[WARNING] could not read image {img_path}')
        preloaded_images[frame_id] = image
    return preloaded_images
def custom_probs(feature, label_text_features, temperature=1.0):
    """Pick the label that best matches a set of image embeddings.

    The similarity of every embedding row against every label embedding is
    summed over the rows, multiplied by ``temperature``, and turned into a
    softmax distribution over labels.

    Args:
        feature: ``(N, D)`` array of image embeddings. A single ``(D,)``
            vector is treated as ``N == 1``.
        label_text_features: ``(C, D)`` array of label embeddings.
        temperature: multiplier applied to the summed similarities before the
            softmax. 1.0 reproduces the previous behaviour.

    Returns:
        ``(prob, label_index)``: the softmax probability of the winning label
        and its row index in ``label_text_features``.
    """
    object_feature = np.atleast_2d(feature)
    raw_similarity = object_feature @ np.asarray(label_text_features).T  # (N, C)
    logits = np.sum(raw_similarity, axis=0) * temperature  # (C,)
    # Shifting by the max leaves the softmax unchanged mathematically, but it
    # prevents exp() from overflowing to inf (and the result becoming nan).
    exp_sim = np.exp(logits - np.max(logits))
    probs = exp_sim / np.sum(exp_sim)
    max_label_id = np.argmax(probs)
    return probs[max_label_id], max_label_id
def _encode_crops(model, crops, device, batch_size=32):
    """Encode preprocessed crops with ``model.encode_image`` in batches of at most ``batch_size``.

    Returns an ``(N, D)`` float32 numpy array of L2-normalised embeddings.
    """
    features = []
    # torch.split yields chunks of at most batch_size rows. The previous
    # torch.chunk(x, len(x) // batch_size) could produce chunks larger than that.
    for batch in torch.split(crops, batch_size):
        with torch.no_grad():
            image_features = model.encode_image(batch.to(device)).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        features.append(image_features.cpu().numpy())
    return np.concatenate(features, axis=0)


def _classify(image_features, label_text_features, temperature=100.0):
    """Softmax-match the mean of ``image_features`` against the label embeddings.

    ``temperature`` multiplies the similarities before the softmax. 100 is the
    factor this script used before (presumably CLIP's logit scale; confirm).
    Returns ``(prob, label_index)`` for the most likely label.
    """
    object_feature = np.mean(image_features, axis=0)
    logits = (object_feature @ label_text_features.T) * temperature  # (C,)
    # Max-shifted softmax: float32 exp overflows above ~88, and sim * 100 can exceed that.
    exp_sim = np.exp(logits - np.max(logits))
    probs = exp_sim / np.sum(exp_sim)
    max_label_id = np.argmax(probs)
    return probs[max_label_id], max_label_id


def main(args):
    """Label every object of every scene in the split with a semantic class and save the results.

    For each scene in the split file:

    1. The scene is skipped if ``data/prediction/<config>/<exp_name>/<scene>.npz``
       already exists.
    2. The per-object records (``'mask'`` point ids, ``'frames'`` with boxes)
       are loaded.
    3. Each object is cropped from its frames, and the crops are encoded with
       ``args.model`` / ``args.preprocess``.
    4. The mean embedding is softmax-matched against the dataset's label text
       features.
    5. ``pred_masks`` / ``pred_score`` / ``pred_classes`` are written to that
       ``.npz``.

    Objects whose class id is 0 are dropped before saving. This includes
    objects for which no crop could be made, because they keep the initial
    class 0.

    Args:
        args: namespace providing at least ``config``, ``exp_name``,
            ``num_images``, ``device``, ``debug``, ``model`` and
            ``preprocess``. ``args.seq_name`` is overwritten per scene.
    """
    # atleast_1d: np.loadtxt returns a 0-d array for a single-line split file,
    # and a 0-d array cannot be iterated.
    scenes = np.atleast_1d(np.loadtxt('/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/splits/scannet200_subset.txt', dtype=str))
    pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
    for scene in tqdm(scenes):
        args.seq_name = scene
        if os.path.exists(f'{pred_dir}/{args.seq_name}.npz'):
            continue
        dataset = ScanNetDataset(scene, use_templates=True)
        total_point_num = dataset.get_scene_points().shape[0]
        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())
        scene_name = dataset.seq_name
        logs_path = f'logs/{args.exp_name}/{scene}'
        # Load the per-scene object list (one record per object).
        # NOTE: allow_pickle can execute code embedded in the file; only load trusted outputs.
        object_list = np.load(f"/home/jovyan/users/bulat/workspace/3drec/det/OV/mask_proj/outputs/30/{scene}/mask_data.npy", allow_pickle=True)
        # Index 0 is used as label->id; index 1 is used (debug only) as id->label name.
        label_info = dataset.get_label_id()
        label2id = label_info[0]
        os.makedirs(pred_dir, exist_ok=True)
        num_instance = len(object_list)
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32),
        }
        # Gather every frame id referenced by any object so each image is read only once.
        frame_ids = list({frame['frame_id'] for object_data in object_list for frame in object_data['frames']})
        preloaded_images = preload_images(scene, frame_ids)
        for idx, object_data in enumerate(object_list):
            croped_images, saved_images = get_cropped_images(object_data, scene_name, args.preprocess, preloaded_images=preloaded_images, num_images=args.num_images)
            if croped_images is None:
                print(f'[INFO] no croped images for object {idx}')
                continue
            features = _encode_crops(args.model, croped_images, args.device)
            prob, max_label_id = _classify(features, label_text_features)
            label_id = label2id[descriptions[max_label_id]]
            pred_dict['pred_score'][idx] = prob
            pred_dict['pred_classes'][idx] = label_id
            binary_mask = np.zeros(total_point_num, dtype=bool)
            binary_mask[list(object_data['mask'])] = True
            pred_dict['pred_masks'][:, idx] = binary_mask
            if args.debug:
                label_name = label_info[1][label_id]
                os.makedirs(logs_path, exist_ok=True)
                cv2.imwrite(f'{logs_path}/{str(idx).zfill(5)}_{label_name}({prob:.2f}).jpg', saved_images)
                print(idx, label_id, label_name, "confidence:", prob)
        # Drop unlabelled objects (class 0), including those with no usable crop.
        keep = pred_dict['pred_classes'] != 0
        pred_dict['pred_masks'] = pred_dict['pred_masks'][:, keep]
        pred_dict['pred_score'] = pred_dict['pred_score'][keep]
        pred_dict['pred_classes'] = pred_dict['pred_classes'][keep]
        np.savez(f'{pred_dir}/{args.seq_name}.npz', **pred_dict)
if __name__ == '__main__':
    # Command-line interface. Some options (num_workers, path_to_predictions,
    # job_id) are not read in this file; presumably update_args or other
    # tooling consumes them. TODO confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='scannet', type=str)
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', default='baseline', type=str)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--num_images', default=5, type=int)
    parser.add_argument('--num_workers', '-n', default=1, type=int)
    parser.add_argument(
        '--path_to_predictions',
        default='/home/jovyan/users/bulat/workspace/3drec/Indoor/Grounded-SAM-2/results/gsam_result/scannet200/yolo/memory_150_classes_198',
        type=str,
    )
    parser.add_argument('--job_id', '-i', default=0, type=int)
    args = update_args(parser.parse_args())
    # The CLIP model is loaded once and travels on the args namespace into main().
    model, preprocess = load_clip(args.device)
    args.model, args.preprocess = model, preprocess
    main(args)