# NOTE: Hugging Face file-viewer residue removed (uploader: bulatko, commit
# 55e58d1 "adding real MK", raw/history/blame links, 17.4 kB) — it was not
# valid Python and would break the module at import time.
import open3d as o3d
import numpy as np
import os
import cv2
from evaluation.constants import SCANNET_LABELS, SCANNET_IDS, SCANNET18_LABELS, SCANNET18_IDS, SCANNETPP84_IDS, SCANNETPP84_LABELS, SCANNET20_LABELS, SCANNET20_IDS, ARKIT_LABELS, ARKIT_IDS
class ScanNetDataset:
    """Frame, geometry and label accessors for a preprocessed ScanNet scene.

    Expects the preprocessing layout under ``<root>/processed/<seq_name>``:
    ``color/`` (``<frame_id>.jpg``), ``depth/`` (``<frame_id>.png`` in
    millimeters, hence ``depth_scale = 1000``), ``pose/``, ``intrinsic/``
    and ``output/{mask,object}``.
    """

    def __init__(self, seq_name, root='data/scannet', use_templates=False) -> None:
        self.seq_name = seq_name
        self.use_templates = use_templates
        self.root = os.path.join(root, 'processed', seq_name)
        self.rgb_dir = f'{self.root}/color'
        self.depth_dir = f'{self.root}/depth'
        self.segmentation_dir = f'{self.root}/output/mask'
        self.object_dict_dir = f'{self.root}/output/object'
        self.point_cloud_path = f'{self.root}/{seq_name}_vh_clean_2.ply'
        self.mesh_path = self.point_cloud_path
        self.extrinsics_dir = f'{self.root}/pose'
        self.intrinsic_dir = f'{self.root}/intrinsic'
        self.label_features_dict = None  # filled lazily by get_label_features()
        self.depth_scale = 1000.0  # raw depth PNGs store millimeters
        self.image_size = self.get_image_size()  # (width, height) of RGB frames
        self.depth_size = self.get_depth_shape()  # (width, height) of depth maps

    def _sorted_rgb_list(self):
        """Return the RGB filenames sorted by their integer frame id."""
        return sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('.')[0]))

    def get_frame_list(self, stride):
        """Return all frame ids (ints) in ascending order.

        NOTE(review): ``stride`` is accepted for interface compatibility but
        was never applied by the original code; that behavior is preserved.
        An unused ``end`` computation was removed.
        """
        return [int(name.split('.')[0]) for name in self._sorted_rgb_list()]

    def get_image_size(self):
        """Return (width, height) of the first RGB frame."""
        image_path = os.path.join(self.rgb_dir, self._sorted_rgb_list()[0])
        image = cv2.imread(image_path)
        return image.shape[:2][::-1]  # (h, w) -> (w, h)

    def get_depth_shape(self):
        """Return (width, height) of the depth map of the first frame."""
        first_id = self._sorted_rgb_list()[0].split('.')[0]
        depth = cv2.imread(os.path.join(self.depth_dir, f'{first_id}.png'), -1)
        return depth.shape[:2][::-1]  # (h, w) -> (w, h)

    def get_intrinsics(self, frame_id):
        """Return the pinhole intrinsics (shared by all frames).

        NOTE(review): the matrix comes from ``intrinsic_depth.txt`` but the
        camera is sized with ``image_size`` (RGB) — confirm the two
        resolutions actually match in the preprocessed data.
        """
        intrinsics = np.loadtxt(f'{self.intrinsic_dir}/intrinsic_depth.txt')
        params = o3d.camera.PinholeCameraIntrinsic()
        params.set_intrinsics(self.image_size[0], self.image_size[1],
                              intrinsics[0, 0], intrinsics[1, 1],
                              intrinsics[0, 2], intrinsics[1, 2])
        return params

    def get_extrinsic(self, frame_id):
        """Load the 4x4 pose matrix for ``frame_id`` from ``pose/<id>.txt``."""
        return np.loadtxt(os.path.join(self.extrinsics_dir, f'{frame_id}.txt'))

    def get_depth(self, frame_id):
        """Load the depth map for ``frame_id`` as float32 meters."""
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return (depth / self.depth_scale).astype(np.float32)

    def get_rgb(self, frame_id, change_color=True):
        """Load the RGB frame; converts BGR->RGB unless ``change_color`` is False."""
        rgb = cv2.imread(os.path.join(self.rgb_dir, f'{frame_id}.jpg'))
        if change_color:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        return rgb

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Load the precomputed mask, optionally resized to the depth resolution.

        Raises:
            FileNotFoundError: if the mask is missing. (The original used
            ``assert False``, which is silently stripped under ``python -O``.)
        """
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(segmentation_path):
            raise FileNotFoundError(f"Segmentation not found: {segmentation_path}")
        segmentation = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        if align_with_depth:
            # Nearest-neighbor keeps integer mask ids intact.
            segmentation = cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)
        return segmentation

    def get_frame_path(self, frame_id):
        """Return the (rgb_path, segmentation_path) pair for ``frame_id``."""
        rgb_path = os.path.join(self.rgb_dir, f'{frame_id}.jpg')
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        return rgb_path, segmentation_path

    def get_label_features(self):
        """Load (and cache) the text features for the ScanNet label set."""
        if self.label_features_dict is None:
            name = 'scannet_templates' if self.use_templates else 'scannet'
            self.label_features_dict = np.load(
                f'data/text_features/{name}.npy', allow_pickle=True).item()
        return self.label_features_dict

    def get_scene_points(self):
        """Return the scene point cloud vertices as an (N, 3) array."""
        cloud = o3d.io.read_point_cloud(self.point_cloud_path)
        return np.asarray(cloud.points)

    def get_label_id(self):
        """Build and cache the ScanNet label<->id mappings."""
        self.class_id = SCANNET_IDS
        self.class_label = SCANNET_LABELS
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label
class ARKitDataset(ScanNetDataset):
    """ScanNet-style accessors for ARKit sequences posed with DUSt3R.

    Differs from the parent only in the dataset root, string frame ids, and
    the ARKit label set / text features.
    """

    def __init__(self, seq_name, root='data/arkit_dust3r_posed'):
        # The parent constructor already derives image_size/depth_size via
        # dynamic dispatch, so the original's extra get_image_size() call
        # (and its byte-identical method override) were redundant — removed.
        super().__init__(seq_name, root)

    def get_frame_list(self, stride):
        """Return all frame ids (strings) in ascending numeric order.

        NOTE(review): ``stride`` is accepted but not applied (original
        behavior preserved); an unused ``end`` computation was removed.
        """
        image_list = sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('.')[0]))
        return [name.split('.')[0] for name in image_list]

    def get_label_id(self):
        """Build and cache the ARKit label<->id mappings."""
        self.class_id = ARKIT_IDS
        self.class_label = ARKIT_LABELS
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label

    def get_label_features(self):
        """Load the ARKit text features (not cached, unlike the parent)."""
        return np.load('data/text_features/arkit.npy', allow_pickle=True).item()
class ITWDataset(ARKitDataset):
    """In-the-wild variant of the ARKit dataset.

    Frame files are named ``<index>_...``, so listings sort on the leading
    integer, and the label vocabulary is derived from per-scene text
    features stored next to the capture.
    """

    def get_image_size(self):
        """Return (width, height) of the first RGB frame."""
        frames = sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('_')[0]))
        first = cv2.imread(os.path.join(self.rgb_dir, frames[0]))
        return first.shape[:2][::-1]

    def get_depth_shape(self):
        """Return (width, height) of the depth map of the first frame."""
        frames = sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('_')[0]))
        stem = frames[0].split('.')[0]
        depth_map = cv2.imread(os.path.join(self.depth_dir, f"{stem}.png"), -1)
        return depth_map.shape[:2][::-1]

    def get_frame_list(self, stride):
        """Return all frame ids (filename stems) ordered by leading index."""
        frames = sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('_')[0]))
        return [name.split('.')[0] for name in frames]

    def get_label_features(self):
        """Load the per-scene text features stored alongside the data."""
        return np.load(f'{self.root}/text_features.npy', allow_pickle=True).item()

    def get_label_id(self):
        """Derive label<->id mappings from the text-feature keys (0..N-1)."""
        self.class_label = list(self.get_label_features().keys())
        self.class_id = list(range(len(self.class_label)))
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label
class WildDataset(ARKitDataset):
    """Accessors for arbitrary captures with a flat ``<root>/<seq_name>`` layout.

    Unlike the ScanNet-style parents, RGB lives in ``images/`` (not
    ``color/``), the point cloud is ``point_cloud.ply``, and both image and
    depth sizes are taken from the depth maps.
    """

    def __init__(self, seq_name, root):
        # Deliberately does NOT call super().__init__: the directory layout
        # and size derivation differ from the ScanNet-style parents.
        self.seq_name = seq_name  # added for parity with the sibling datasets
        self.root = os.path.join(root, seq_name)
        self.rgb_dir = f'{self.root}/images'
        self.depth_dir = f'{self.root}/depth'
        self.segmentation_dir = f'{self.root}/output/mask'
        self.object_dict_dir = f'{self.root}/output/object'
        self.point_cloud_path = f'{self.root}/point_cloud.ply'
        self.mesh_path = self.point_cloud_path
        self.extrinsics_dir = f'{self.root}/pose'
        self.intrinsic_dir = f'{self.root}/intrinsic'
        self.label_features_dict = None
        self.depth_scale = 1000.0  # depth PNGs store millimeters
        # Both sizes come from the depth maps; the original called
        # get_depth_shape() twice — compute it once.
        self.image_size = self.depth_size = self.get_depth_shape()

    def get_label_features(self):
        """Load the per-scene text features stored alongside the data."""
        return np.load(f'{self.root}/text_features.npy', allow_pickle=True).item()

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Load the mask, always resized to the depth resolution.

        NOTE(review): ``align_with_depth`` is ignored here (the resize always
        happens) — original behavior kept.

        Raises:
            FileNotFoundError: if the mask is missing (the original used
            ``assert False``, which is stripped under ``python -O``).
        """
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(segmentation_path):
            raise FileNotFoundError(f"Segmentation not found: {segmentation_path}")
        segmentation = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_label_id(self):
        """Derive label<->id mappings from the text-feature keys (0..N-1)."""
        self.class_label = list(self.get_label_features().keys())
        self.class_id = list(range(len(self.class_label)))
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label
class ScannetPP2Dataset(ScanNetDataset):
    """ScanNet++ sequences posed with DUSt3R.

    Frame files are named ``<prefix>_<index>.<ext>``, so listings sort on
    the integer after the underscore. Masks are always resized to the depth
    resolution, and the 84-class ScanNet++ label set is used.
    """

    def __init__(self, seq_name, root='data/scannetpp_dust3r_posed'):
        # The parent constructor already derives image_size/depth_size via
        # the overrides below (dynamic dispatch), so the original's recompute
        # of both was redundant — removed.
        super().__init__(seq_name, root)
        self.point_cloud_path = f'{self.root}/{seq_name}.ply'
        # Bug fix: the parent aliased mesh_path to the old '*_vh_clean_2.ply'
        # path before point_cloud_path was overridden; keep them in sync.
        self.mesh_path = self.point_cloud_path

    @staticmethod
    def _frame_index(name):
        """Sort key: the integer after the underscore in '<prefix>_<index>.<ext>'."""
        return int(name.split('.')[0].split('_')[1])

    def get_image_size(self):
        """Return (width, height) of the first RGB frame."""
        frames = sorted(os.listdir(self.rgb_dir), key=self._frame_index)
        image = cv2.imread(os.path.join(self.rgb_dir, frames[0]))
        return image.shape[:2][::-1]

    def get_depth_shape(self):
        """Return (width, height) of the depth map of the first frame."""
        frames = sorted(os.listdir(self.rgb_dir), key=self._frame_index)
        depth = cv2.imread(os.path.join(self.depth_dir, f"{frames[0].split('.')[0]}.png"), -1)
        return depth.shape[:2][::-1]

    def get_frame_list(self, stride):
        """Return all frame ids (filename stems) ordered by frame index.

        NOTE(review): ``stride`` is accepted but not applied (original behavior).
        """
        frames = sorted(os.listdir(self.rgb_dir), key=self._frame_index)
        return [name.split('.')[0] for name in frames]

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Load the mask, always resized to the depth resolution.

        NOTE(review): ``align_with_depth`` is ignored here (the resize always
        happens) — original behavior kept.

        Raises:
            FileNotFoundError: if the mask is missing (the original used
            ``assert False``, which is stripped under ``python -O``).
        """
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(segmentation_path):
            raise FileNotFoundError(f"Segmentation not found: {segmentation_path}")
        segmentation = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_label_id(self):
        """Build and cache the 84-class ScanNet++ label<->id mappings."""
        self.class_id = SCANNETPP84_IDS
        self.class_label = SCANNETPP84_LABELS
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label

    def get_label_features(self):
        """Load the ScanNet++-84 text features (not cached)."""
        return np.load('data/text_features/scannetpp84.npy', allow_pickle=True).item()

    # The original get_depth/get_intrinsics overrides were byte-identical to
    # the parent implementations and have been removed.
class ScanNet18Dataset:
    """ScanNet accessors for the 18-class benchmark; frame ids are strings.

    Same on-disk layout as ScanNetDataset, but the point cloud is
    ``<seq_name>.ply``, segmentation masks are always resized to the depth
    resolution, and text features are loaded on every call (no cache).
    """

    def __init__(self, seq_name, root='data/scannet') -> None:
        self.seq_name = seq_name
        self.root = os.path.join(root, 'processed', seq_name)
        self.rgb_dir = f'{self.root}/color'
        self.depth_dir = f'{self.root}/depth'
        self.segmentation_dir = f'{self.root}/output/mask'
        self.object_dict_dir = f'{self.root}/output/object'
        self.point_cloud_path = f'{self.root}/{seq_name}.ply'
        self.mesh_path = self.point_cloud_path
        self.extrinsics_dir = f'{self.root}/pose'
        self.intrinsic_dir = f'{self.root}/intrinsic'
        self.depth_scale = 1000.0  # depth PNGs store millimeters
        self.image_size = self.get_image_size()  # (width, height) of RGB frames
        self.depth_size = self.get_depth_shape()  # (width, height) of depth maps

    def _sorted_rgb_list(self):
        """Return the RGB filenames sorted by their integer frame id."""
        return sorted(os.listdir(self.rgb_dir), key=lambda name: int(name.split('.')[0]))

    def get_frame_list(self, stride):
        """Return all frame ids (strings) in ascending numeric order.

        NOTE(review): ``stride`` is accepted for interface compatibility but
        was never applied by the original code; that behavior is preserved.
        An unused ``end`` computation was removed.
        """
        return [name.split('.')[0] for name in self._sorted_rgb_list()]

    def get_image_size(self):
        """Return (width, height) of the first RGB frame."""
        image = cv2.imread(os.path.join(self.rgb_dir, self._sorted_rgb_list()[0]))
        return image.shape[:2][::-1]  # (h, w) -> (w, h)

    def get_depth_shape(self):
        """Return (width, height) of the depth map of the first frame."""
        first_id = self._sorted_rgb_list()[0].split('.')[0]
        depth = cv2.imread(os.path.join(self.depth_dir, f'{first_id}.png'), -1)
        return depth.shape[:2][::-1]  # (h, w) -> (w, h)

    def get_intrinsics(self, frame_id):
        """Return the pinhole intrinsics (shared by all frames).

        NOTE(review): loads ``intrinsic_depth.txt`` but sizes the camera with
        ``image_size`` (RGB) — confirm the two resolutions actually match.
        """
        intrinsics = np.loadtxt(f'{self.intrinsic_dir}/intrinsic_depth.txt')
        params = o3d.camera.PinholeCameraIntrinsic()
        params.set_intrinsics(self.image_size[0], self.image_size[1],
                              intrinsics[0, 0], intrinsics[1, 1],
                              intrinsics[0, 2], intrinsics[1, 2])
        return params

    def get_extrinsic(self, frame_id):
        """Load the 4x4 pose matrix for ``frame_id`` from ``pose/<id>.txt``."""
        return np.loadtxt(os.path.join(self.extrinsics_dir, f'{frame_id}.txt'))

    def get_depth(self, frame_id):
        """Load the depth map for ``frame_id`` as float32 meters."""
        depth = cv2.imread(os.path.join(self.depth_dir, f'{frame_id}.png'), -1)
        return (depth / self.depth_scale).astype(np.float32)

    def get_rgb(self, frame_id, change_color=True):
        """Load the RGB frame; converts BGR->RGB unless ``change_color`` is False."""
        rgb = cv2.imread(os.path.join(self.rgb_dir, f'{frame_id}.jpg'))
        if change_color:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        return rgb

    def get_segmentation(self, frame_id, align_with_depth=False):
        """Load the mask, always resized to the depth resolution.

        NOTE(review): ``align_with_depth`` is ignored here (the resize always
        happens) — original behavior kept.

        Raises:
            FileNotFoundError: if the mask is missing (the original used
            ``assert False``, which is stripped under ``python -O``).
        """
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        if not os.path.exists(segmentation_path):
            raise FileNotFoundError(f"Segmentation not found: {segmentation_path}")
        segmentation = cv2.imread(segmentation_path, cv2.IMREAD_UNCHANGED)
        return cv2.resize(segmentation, self.depth_size, interpolation=cv2.INTER_NEAREST)

    def get_frame_path(self, frame_id):
        """Return the (rgb_path, segmentation_path) pair for ``frame_id``."""
        rgb_path = os.path.join(self.rgb_dir, f'{frame_id}.jpg')
        segmentation_path = os.path.join(self.segmentation_dir, f'{frame_id}.png')
        return rgb_path, segmentation_path

    def get_label_features(self):
        """Load the ScanNet-18 text features (not cached)."""
        return np.load('data/text_features/scannet18.npy', allow_pickle=True).item()

    def get_scene_points(self):
        """Return the scene point cloud vertices as an (N, 3) array."""
        cloud = o3d.io.read_point_cloud(self.point_cloud_path)
        return np.asarray(cloud.points)

    def get_label_id(self):
        """Build and cache the 18-class label<->id mappings."""
        self.class_id = SCANNET18_IDS
        self.class_label = SCANNET18_LABELS
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label
class ScanNet20Dataset(ScanNet18Dataset):
    """ScanNet18 variant evaluated on the 20-class ScanNet benchmark."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.point_cloud_path = f'{self.root}/{self.seq_name}_vh_clean_2.ply'
        # Bug fix: keep the mesh alias in sync — the parent set mesh_path
        # before this reassignment, leaving it pointing at '<seq_name>.ply'.
        self.mesh_path = self.point_cloud_path

    def get_label_features(self):
        """Load the ScanNet-20 text features (not cached).

        Bug fix: the original hard-coded a machine-specific absolute path
        ('/home/jovyan/users/lemeshko/...'); use the repository-relative
        convention every sibling dataset follows.
        """
        return np.load('data/text_features/scannet20.npy', allow_pickle=True).item()

    def get_label_id(self):
        """Build and cache the 20-class label<->id mappings."""
        self.class_id = SCANNET20_IDS
        self.class_label = SCANNET20_LABELS
        self.label2id = dict(zip(self.class_label, self.class_id))
        self.id2label = dict(zip(self.class_id, self.class_label))
        return self.label2id, self.id2label