import os

import cv2
import numpy as np
import open3d as o3d

from evaluation.constants import (
    ARKIT_IDS,
    ARKIT_LABELS,
    SCANNET18_IDS,
    SCANNET18_LABELS,
    SCANNET20_IDS,
    SCANNET20_LABELS,
    SCANNET_IDS,
    SCANNET_LABELS,
    SCANNETPP84_IDS,
    SCANNETPP84_LABELS,
)


def _build_label_maps(labels, ids):
    """Return (label -> id, id -> label) dicts for parallel label/id sequences."""
    return dict(zip(labels, ids)), dict(zip(ids, labels))


class ScanNetDataset:
    """Accessors for one processed ScanNet sequence.

    Expects the layout ``root/processed/<seq_name>/{color,depth,pose,intrinsic,output}``
    with RGB frames named ``<int>.jpg``, depth maps ``<int>.png`` (millimeters),
    and per-frame poses ``<int>.txt``.
    """

    def __init__(self, seq_name, root='data/scannet', use_templates=False) -> None:
        self.seq_name = seq_name
        self.use_templates = use_templates
        self.root = os.path.join(root, 'processed', seq_name)
        self.rgb_dir = f'{self.root}/color'
        self.depth_dir = f'{self.root}/depth'
        self.segmentation_dir = f'{self.root}/output/mask'
        self.object_dict_dir = f'{self.root}/output/object'
        self.point_cloud_path = f'{self.root}/{seq_name}_vh_clean_2.ply'
        self.mesh_path = self.point_cloud_path
        self.extrinsics_dir = f'{self.root}/pose'
        self.intrinsic_dir = f'{self.root}/intrinsic'
        # Lazy cache for get_label_features().
        self.label_features_dict = None
        # Depth PNGs store millimeters; divide by this to get meters.
        self.depth_scale = 1000.0
        self.image_size = self.get_image_size()
        self.depth_size = self.get_depth_shape()

    def get_frame_list(self, stride):
        """Return all frame ids (as ints) in ascending numeric order.

        NOTE(review): ``stride`` is accepted for interface parity but not
        applied — every frame is returned.
        """
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('.')[0]))
        return [int(name.split('.')[0]) for name in image_list]

    def get_image_size(self):
        """Return (width, height) of the RGB frames, read from the first frame."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('.')[0]))
        image = cv2.imread(os.path.join(self.rgb_dir, image_list[0]))
        return image.shape[:2][::-1]

    def get_depth_shape(self):
        """Return (width, height) of the depth maps, read from the first frame."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('.')[0]))
        frame_id = image_list[0].split('.')[0]
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return depth.shape[:2][::-1]

    def get_intrinsics(self, frame_id):
        """Return an Open3D PinholeCameraIntrinsic for this sequence.

        ``frame_id`` is unused: one shared intrinsic file serves all frames.
        NOTE(review): loads *depth* intrinsics but pairs them with the RGB
        ``image_size`` — presumably color and depth share intrinsics here;
        confirm against the preprocessing pipeline.
        """
        intrinsics = np.loadtxt(f'{self.intrinsic_dir}/intrinsic_depth.txt')
        params = o3d.camera.PinholeCameraIntrinsic()
        params.set_intrinsics(
            self.image_size[0], self.image_size[1],
            intrinsics[0, 0], intrinsics[1, 1],
            intrinsics[0, 2], intrinsics[1, 2],
        )
        return params

    def get_extrinsic(self, frame_id):
        """Return the 4x4 camera-to-world pose matrix of ``frame_id``."""
        return np.loadtxt(os.path.join(self.extrinsics_dir, f'{frame_id}.txt'))

    def get_depth(self, frame_id):
        """Return the depth map of ``frame_id`` in meters as float32."""
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return (depth / self.depth_scale).astype(np.float32)

    def get_rgb(self, frame_id, change_color=True):
        """Return the RGB frame; ``change_color`` converts OpenCV BGR to RGB."""
        rgb = cv2.imread(os.path.join(self.rgb_dir, f'{frame_id}.jpg'))
        if change_color:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        return rgb

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Return the mask image for ``frame_id``.

        Raises FileNotFoundError when the mask is missing (was ``assert False``,
        which is stripped under ``python -O``). When ``align_with_depth`` is
        set, the mask is nearest-neighbor resized to the depth resolution.
        """
        path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(path):
            raise FileNotFoundError(f"Segmentation not found: {path}")
        segmentation = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if align_with_depth:
            segmentation = cv2.resize(segmentation, self.depth_size,
                                      interpolation=cv2.INTER_NEAREST)
        return segmentation

    def get_frame_path(self, frame_id):
        """Return (rgb_path, segmentation_path) for ``frame_id``."""
        rgb_path = os.path.join(self.rgb_dir, f'{frame_id}.jpg')
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        return rgb_path, segmentation_path

    def get_label_features(self):
        """Lazily load and cache the {label: text-feature} dict."""
        if self.label_features_dict is None:
            path = ('data/text_features/scannet_templates.npy'
                    if self.use_templates else 'data/text_features/scannet.npy')
            self.label_features_dict = np.load(path, allow_pickle=True).item()
        return self.label_features_dict

    def get_scene_points(self):
        """Return the scene point cloud vertices as an (N, 3) array."""
        cloud = o3d.io.read_point_cloud(self.point_cloud_path)
        return np.asarray(cloud.points)

    def get_label_id(self):
        """Populate and return (label2id, id2label) for the full ScanNet label set."""
        self.class_id = SCANNET_IDS
        self.class_label = SCANNET_LABELS
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label


class ARKitDataset(ScanNetDataset):
    """ARKit sequence with the ScanNet layout; frame ids are kept as strings."""

    def __init__(self, seq_name, root='data/arkit_dust3r_posed'):
        super().__init__(seq_name, root)

    def get_frame_list(self, stride):
        """Return frame ids as strings, sorted numerically. ``stride`` is unused."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('.')[0]))
        return [name.split('.')[0] for name in image_list]

    def get_label_id(self):
        """Populate and return (label2id, id2label) for the ARKit label set."""
        self.class_id = ARKIT_IDS
        self.class_label = ARKIT_LABELS
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label

    def get_label_features(self):
        """Load the ARKit text features (not cached, unlike the base class)."""
        return np.load('data/text_features/arkit.npy', allow_pickle=True).item()


class ITWDataset(ARKitDataset):
    """'In-the-wild' variant: frame files sort by the prefix before '_'."""

    def get_image_size(self):
        """Return (width, height) of the RGB frames (underscore-keyed sort)."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('_')[0]))
        image = cv2.imread(os.path.join(self.rgb_dir, image_list[0]))
        return image.shape[:2][::-1]

    def get_depth_shape(self):
        """Return (width, height) of the depth maps (underscore-keyed sort)."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('_')[0]))
        frame_id = image_list[0].split('.')[0]
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return depth.shape[:2][::-1]

    def get_frame_list(self, stride):
        """Return frame ids (filename stems) as strings. ``stride`` is unused."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('_')[0]))
        return [name.split('.')[0] for name in image_list]

    def get_label_features(self):
        """Load per-scene text features stored alongside the sequence."""
        return np.load(f'{self.root}/text_features.npy', allow_pickle=True).item()

    def get_label_id(self):
        """Derive the label set from the per-scene text-feature keys; ids are 0..N-1."""
        self.class_label = list(self.get_label_features().keys())
        self.class_id = list(range(len(self.class_label)))
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label


class WildDataset(ARKitDataset):
    """Uncalibrated capture living directly under ``root/<seq_name>``.

    Does NOT call ``super().__init__`` on purpose: there is no ``processed``
    subfolder and images live in ``images/`` instead of ``color/``.
    """

    def __init__(self, seq_name, root):
        self.seq_name = seq_name
        self.root = os.path.join(root, seq_name)
        self.rgb_dir = f'{self.root}/images'
        self.depth_dir = f'{self.root}/depth'
        self.segmentation_dir = f'{self.root}/output/mask'
        self.object_dict_dir = f'{self.root}/output/object'
        self.point_cloud_path = f'{self.root}/point_cloud.ply'
        self.mesh_path = self.point_cloud_path
        self.extrinsics_dir = f'{self.root}/pose'
        self.intrinsic_dir = f'{self.root}/intrinsic'
        self.label_features_dict = None
        self.depth_scale = 1000.0
        # Both sizes come from the depth maps: frames are treated as depth-aligned.
        self.image_size = self.get_depth_shape()
        self.depth_size = self.get_depth_shape()

    def get_label_features(self):
        """Load per-scene text features stored alongside the sequence."""
        return np.load(f'{self.root}/text_features.npy', allow_pickle=True).item()

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Return the mask for ``frame_id``, always resized to depth resolution.

        ``align_with_depth`` is kept for interface parity but has no effect.
        """
        path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(path):
            raise FileNotFoundError(f"Segmentation not found: {path}")
        segmentation = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_label_id(self):
        """Derive the label set from the per-scene text-feature keys; ids are 0..N-1."""
        self.class_label = list(self.get_label_features().keys())
        self.class_id = list(range(len(self.class_label)))
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label


class ScannetPP2Dataset(ScanNetDataset):
    """ScanNet++ sequence; frame files are ``<prefix>_<int>.<ext>``."""

    def __init__(self, seq_name, root='data/scannetpp_dust3r_posed'):
        super().__init__(seq_name, root)
        # super() already computed image_size/depth_size via our overrides.
        self.point_cloud_path = f'{self.root}/{seq_name}.ply'
        # Keep the alias in sync (super() pointed it at the vh_clean_2 path).
        self.mesh_path = self.point_cloud_path

    def get_image_size(self):
        """Return (width, height) of the RGB frames (sorted by the '_<int>' suffix)."""
        image_list = sorted(os.listdir(self.rgb_dir),
                            key=lambda x: int(x.split('.')[0].split('_')[1]))
        image = cv2.imread(os.path.join(self.rgb_dir, image_list[0]))
        return image.shape[:2][::-1]

    def get_depth_shape(self):
        """Return (width, height) of the depth maps (sorted by the '_<int>' suffix)."""
        image_list = sorted(os.listdir(self.rgb_dir),
                            key=lambda x: int(x.split('.')[0].split('_')[1]))
        frame_id = image_list[0].split('.')[0]
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return depth.shape[:2][::-1]

    def get_frame_list(self, stride):
        """Return frame ids (filename stems) as strings. ``stride`` is unused."""
        image_list = sorted(os.listdir(self.rgb_dir),
                            key=lambda x: int(x.split('.')[0].split('_')[1]))
        return [name.split('.')[0] for name in image_list]

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Return the mask for ``frame_id``, always resized to depth resolution.

        ``align_with_depth`` is kept for interface parity but has no effect.
        """
        path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(path):
            raise FileNotFoundError(f"Segmentation not found: {path}")
        segmentation = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_label_id(self):
        """Populate and return (label2id, id2label) for the ScanNet++ 84-class set."""
        self.class_id = SCANNETPP84_IDS
        self.class_label = SCANNETPP84_LABELS
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label

    def get_label_features(self):
        """Load the ScanNet++ 84-class text features (not cached)."""
        return np.load('data/text_features/scannetpp84.npy', allow_pickle=True).item()

    # get_depth / get_intrinsics were byte-identical re-declarations of the
    # base-class methods and are now simply inherited.


class ScanNet18Dataset(ScanNetDataset):
    """ScanNet sequence evaluated against the 18-class label set.

    Previously a full copy of ScanNetDataset; now inherits it and overrides
    only what actually differs (point cloud path, string frame ids,
    always-resized masks, 18-class labels/features).
    """

    def __init__(self, seq_name, root='data/scannet') -> None:
        super().__init__(seq_name, root)
        self.point_cloud_path = f'{self.root}/{seq_name}.ply'
        self.mesh_path = self.point_cloud_path

    def get_frame_list(self, stride):
        """Return frame ids as strings (base class returns ints). ``stride`` unused."""
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda x: int(x.split('.')[0]))
        return [name.split('.')[0] for name in image_list]

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Return the mask for ``frame_id``, always resized to depth resolution.

        ``align_with_depth`` is kept for interface parity but has no effect.
        """
        path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(path):
            raise FileNotFoundError(f"Segmentation not found: {path}")
        segmentation = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_label_features(self):
        """Load the 18-class text features (not cached)."""
        return np.load('data/text_features/scannet18.npy', allow_pickle=True).item()

    def get_label_id(self):
        """Populate and return (label2id, id2label) for the 18-class set."""
        self.class_id = SCANNET18_IDS
        self.class_label = SCANNET18_LABELS
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label


class ScanNet20Dataset(ScanNet18Dataset):
    """ScanNet sequence evaluated against the 20-class label set."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.point_cloud_path = f'{self.root}/{self.seq_name}_vh_clean_2.ply'
        # Keep the alias in sync (super() pointed it at the 18-class path).
        self.mesh_path = self.point_cloud_path

    def get_label_features(self):
        """Load the 20-class text features.

        Fixed: the path was a machine-specific absolute path
        (/home/jovyan/users/lemeshko/...); now relative like every sibling.
        """
        return np.load('data/text_features/scannet20.npy', allow_pickle=True).item()

    def get_label_id(self):
        """Populate and return (label2id, id2label) for the 20-class set."""
        self.class_id = SCANNET20_IDS
        self.class_label = SCANNET20_LABELS
        self.label2id, self.id2label = _build_label_maps(self.class_label, self.class_id)
        return self.label2id, self.id2label