'''
Assign semantic labels to predicted 3D object instances.

For every scene, each object's 2D bounding boxes are cropped from the posed RGB
frames at several expansion levels and encoded with CLIP; the averaged image
feature is matched against the dataset's label text features. The resulting
class ids, confidence scores and point masks are saved as one .npz file per scene.
'''
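
# Example invocation (sketch; the script filename below is a placeholder for this file,
# the flags are the ones defined in the argparse block at the bottom):
#   python label_objects.py --config scannet --exp_name baseline -n 4 -i 0 --debug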
import argparse
import os

import cv2
import numpy as np
import open_clip
import torch
from PIL import Image
from tqdm import tqdm

from dataset.scannet import ScanNetDataset
from utils.config import update_args

# Number of progressively expanded crops taken around each object bounding box.
LEVELS = 3


def load_clip(device):
    '''Load the open_clip ViT-H-14 model (laion2b weights) and its preprocessing transform.'''
    print('[INFO] loading CLIP model...')
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    model.to(device)
    model.eval()
    print('[INFO] finished loading CLIP model')
    return model, preprocess


def box_multi_level(bbox, shape, level, expansion_ratio):
    '''Expand a (left, top, right, bottom) box by level * expansion_ratio, clipped to the image bounds.'''
    left, top, right, bottom = bbox
    if level == 0:
        return left, top, right, bottom

    x_exp = int(abs(right - left) * expansion_ratio) * level
    y_exp = int(abs(bottom - top) * expansion_ratio) * level
    return max(0, left - x_exp), max(0, top - y_exp), min(shape[1], right + x_exp), min(shape[0], bottom + y_exp)


def get_cropped_images(obj, scene_id, preprocess, num_images=5, expansion_ratio=0.1, preloaded_images=None):
    '''Crop an object's bounding box at multiple expansion levels from up to `num_images` frames.

    Returns the stack of preprocessed crops for CLIP and a concatenated 64x64 thumbnail strip (BGR) for debugging.
    '''
    cropped_images = []
    thumbnails = []
    for frame in obj['frames'][:num_images]:
        image = preloaded_images[frame['frame_id']]
        x1, y1, x2, y2 = frame['bbox']
        for level in range(LEVELS):
            # The box is updated in place, so every level expands the previous crop a bit further.
            x1, y1, x2, y2 = box_multi_level((x1, y1, x2, y2), np.asarray(image).shape, level, expansion_ratio)
            pil_image = pad_into_square(Image.fromarray(cv2.cvtColor(image[y1:y2, x1:x2], cv2.COLOR_BGR2RGB)))
            cropped_images.append(preprocess(pil_image))
            thumbnails.append(np.asarray(pil_image.resize((64, 64))))
    return torch.stack(cropped_images), np.concatenate(thumbnails, axis=1)[..., ::-1]


def pad_into_square(image):
    '''Pad a PIL image with white borders so that it becomes square, keeping the content centered.'''
    width, height = image.size
    new_size = max(width, height)
    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
    left = (new_size - width) // 2
    top = (new_size - height) // 2
    new_image.paste(image, (left, top))
    return new_image


def preload_images(scene, frame_ids):
    '''Read all required posed RGB frames of a scene into memory, keyed by frame id.'''
    preloaded_images = {}
    for frame_id in frame_ids:
        img_path = f'/home/jovyan/users/lemeshko/mmdetection3d/data/scannet/posed_images/{scene}/{str(frame_id).zfill(5)}.jpg'
        image = cv2.imread(img_path)
        preloaded_images[frame_id] = image
    return preloaded_images


def custom_probs(feature, label_text_features):
    '''Alternative scoring that sums per-crop similarities before the softmax (currently unused by main).'''
    raw_similarity = feature @ label_text_features.T
    raw_similarity = np.sum(raw_similarity, axis=0, keepdims=True)

    exp_sim = np.exp(raw_similarity)
    prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
    probs = np.max(prob, axis=0)
    max_label_id = np.argmax(probs)
    prob = probs[max_label_id]
    return prob, max_label_id


def main(args):
    # Scenes are sharded across workers: worker `job_id` processes every `num_workers`-th scene.
    scenes = np.loadtxt('/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/splits/scannet200_subset.txt', dtype=str)
    scenes = scenes[args.job_id::args.num_workers]

    for scene in tqdm(scenes):
        frame_ids = set()
        args.seq_name = scene

        # Skip scenes whose predictions have already been written.
        pred_dir = os.path.join('data/prediction', args.config, args.exp_name)
        if os.path.exists(f'{pred_dir}/{args.seq_name}.npz'):
            continue

        dataset = ScanNetDataset(scene)
        total_point_num = dataset.get_scene_points().shape[0]

        # Precomputed CLIP text features for every candidate label.
        label_features_dict = dataset.get_label_features()
        label_text_features = np.stack(list(label_features_dict.values()))
        descriptions = list(label_features_dict.keys())

        scene_name = dataset.seq_name
        logs_path = f'logs/{args.exp_name}/{scene}'
        object_dict = np.load(os.path.join(args.path_to_predictions, scene_name, 'infos.npy'), allow_pickle=True).item()

        # Mappings between text labels and numeric class ids.
        label_ids = dataset.get_label_id()
        label2id, id2label = label_ids[0], label_ids[1]

        os.makedirs(pred_dir, exist_ok=True)
        num_instance = len(object_dict)
        pred_dict = {
            "pred_masks": np.zeros((total_point_num, num_instance), dtype=bool),
            "pred_score": np.ones(num_instance),
            "pred_classes": np.zeros(num_instance, dtype=np.int32)
        }

        # Collect every frame referenced by any object so each image is read from disk only once.
        for obj in object_dict.values():
            frame_ids.update([a['frame_id'] for a in obj['frames']])
        frame_ids = list(frame_ids)
        preloaded_images = preload_images(scene, frame_ids)
        print(f'[INFO] processing {scene_name}: {num_instance} objects')

        for idx, (key, obj) in enumerate(object_dict.items()):
            cropped_images, saved_images = get_cropped_images(obj, scene_name, args.preprocess, args.num_images, preloaded_images=preloaded_images)

            # Encode the crops with CLIP in chunks of roughly 32 images.
            bs = 32
            chunks = torch.chunk(cropped_images, max(1, len(cropped_images) // bs))

            features = []
            for images in chunks:
                images = images.to(args.device)
                with torch.no_grad():
                    image_features = args.model.encode_image(images).float()
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                features.extend(image_features.cpu().numpy())

            # Average the per-crop features and softmax the cosine similarities against the
            # label text features (the factor 100 plays the role of CLIP's logit scale).
            feature = np.stack(features)
            object_feature = np.mean(feature, axis=0, keepdims=True)
            raw_similarity = np.dot(object_feature, label_text_features.T)
            exp_sim = np.exp(raw_similarity * 100)
            prob = exp_sim / np.sum(exp_sim, axis=1, keepdims=True)
            probs = np.max(prob, axis=0)
            max_label_id = np.argmax(probs)
            prob = probs[max_label_id]
            pred_dict['pred_score'][idx] = prob

            label_id = label2id[descriptions[max_label_id]]
            pred_dict['pred_classes'][idx] = label_id

            # Turn the object's point ids into a dense boolean mask over the whole scene.
            point_ids = obj['mask']
            binary_mask = np.zeros(total_point_num, dtype=bool)
            binary_mask[list(point_ids)] = True
            pred_dict['pred_masks'][:, idx] = binary_mask

            if args.debug:
                os.makedirs(logs_path, exist_ok=True)
                cv2.imwrite(f'{logs_path}/{str(key).zfill(5)}_{id2label[label_id]}({prob:.2f}).jpg', saved_images)
                print(key, label_id, id2label[label_id], "confidence:", prob)

        # Human-readable class names for the predicted instances (only used for inspection).
        pred_classes = [id2label[i] for i in pred_dict['pred_classes']]

        np.savez(f'{pred_dir}/{args.seq_name}.npz', **pred_dict)
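
# The per-scene output can be loaded back like this (sketch; the exact path depends on --config/--exp_name):
#   data = np.load('data/prediction/scannet/baseline/<scene>.npz')
#   data['pred_masks'], data['pred_score'], data['pred_classes']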

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_images', type=int, default=5)
    parser.add_argument('--num_workers', '-n', type=int, default=1)
    parser.add_argument('--path_to_predictions', type=str, default='/home/jovyan/users/bulat/workspace/3drec/Indoor/Grounded-SAM-2/results/gsam_result/scannet200/yolo/memory_150_classes_198')
    parser.add_argument('--job_id', '-i', type=int, default=0)
    args = parser.parse_args()
    args = update_args(args)

    # Load CLIP once and attach it to args so it is shared across all scenes.
    model, preprocess = load_clip(args.device)
    args.model = model
    args.preprocess = preprocess
    main(args)