# File: zoo3d/MaskClustering/preprocess/scannet/prepare_single_gt.py
# Last commit: 55e58d1 "adding real MK" (author avatar: bulatko) - file size 2.96 kB
# (Hosting-page navigation residue: "raw", "history blame".)
import os
import argparse
import json
import numpy as np
import pandas as pd
import sys
# Extend the module search path with a MaskClustering checkout; presumably this is
# what lets the `evaluation.constants` import that follows resolve.
# NOTE(review): the path is a hard-coded absolute location in one user's workspace -
# it likely needs adjusting (or replacing with PYTHONPATH) on other machines; confirm.
sys.path.append('/home/jovyan/users/bulat/workspace/3drec/MaskClustering')
# Debug output: prints the effective module search path every time this file is loaded.
print("prepare_single_gt sys.path", sys.path)
from evaluation.constants import SCANNET_IDS
def export_gt(filename, label_ids, instance_ids):
    """Write the combined ground-truth encoding to a text file.

    Each value is encoded as ``label_ids * 1000 + instance_ids + 1`` and the
    result is written with ``np.savetxt`` using the integer format ``%d``.

    Args:
        filename: Destination path handed to ``np.savetxt``.
        label_ids: Semantic label ids (array-like supporting arithmetic).
        instance_ids: Instance ids, combined element-wise with ``label_ids``.
    """
    # Semantic label occupies the thousands; the instance id (plus one) the remainder.
    semantic_part = label_ids * 1000
    encoded = semantic_part + instance_ids
    encoded = encoded + 1
    np.savetxt(filename, encoded, fmt='%d')
def point_indices_from_group(seg_indices, group, labels_pd):
    """Return the point indices covered by a segment group and its label id.

    Args:
        seg_indices: Array holding, for every point, the id of its segment.
        group: Mapping with a ``'segments'`` list of segment ids and a
            ``'label'`` category name.
        labels_pd: DataFrame with ``'raw_category'`` and ``'id'`` columns used
            to translate the category name into a numeric id.

    Returns:
        A ``(point_ids, label_id)`` tuple. ``point_ids`` are the positions in
        ``seg_indices`` whose segment id is listed in the group. ``label_id``
        is the id of the first row whose ``'raw_category'`` equals the group's
        label, or 0 when no row matches or the id is not in ``SCANNET_IDS``.
    """
    wanted_segments = np.array(group['segments'])
    category = group['label']

    # Translate the category name into its numeric id (first match wins).
    matching_rows = labels_pd[labels_pd['raw_category'] == category]
    candidate_ids = matching_rows['id']
    if len(candidate_ids) > 0:
        label_id = int(candidate_ids.iloc[0])
    else:
        label_id = 0

    # Categories outside the evaluated set collapse to 0.
    if label_id not in SCANNET_IDS:
        label_id = 0

    # Select the points whose segment id belongs to this group.
    in_group = np.isin(seg_indices, wanted_segments)
    point_ids = np.where(in_group)[0]
    return point_ids, label_id
def handle_single_scene(scene_path, output_path, labels_pd, scene_name):
    """Build and save the ground-truth file for one scene.

    Reads ``<scene_name>_vh_clean_2.0.010000.segs.json`` and
    ``<scene_name>.aggregation.json`` from ``scene_path``, assigns a label id
    and an instance id to the points of every segment group, and writes the
    result to ``<output_path>/<scene_name>.txt`` through ``export_gt``.

    Args:
        scene_path: Directory containing the scene's segmentation and
            aggregation JSON files.
        output_path: Directory that receives the ``.txt`` ground-truth file.
        labels_pd: Label-map DataFrame forwarded to
            ``point_indices_from_group``.
        scene_name: Scene identifier used to build the file names.
    """
    segs_path = os.path.join(scene_path, f'{scene_name}_vh_clean_2.0.010000.segs.json')
    agg_path = os.path.join(scene_path, f'{scene_name}.aggregation.json')
    gt_path = os.path.join(output_path, f'{scene_name}.txt')

    # Per-point segment ids.
    with open(segs_path) as segs_fh:
        segs_payload = json.load(segs_fh)
    seg_indices = np.array(segs_payload['segIndices'])

    # Segment groups (one entry per annotated object).
    with open(agg_path) as agg_fh:
        agg_payload = json.load(agg_fh)
    seg_groups = np.array(agg_payload['segGroups'])

    # Points not covered by any group keep label 0 and instance 0.
    num_points = len(seg_indices)
    labelled_pc = np.zeros((num_points, 1))
    instance_ids = np.zeros((num_points, 1))
    for group in seg_groups:
        point_ids, label_id = point_indices_from_group(seg_indices, group, labels_pd)
        labelled_pc[point_ids] = label_id
        # Shift group ids by one so that 0 stays reserved for "no instance".
        instance_ids[point_ids] = group['id'] + 1

    export_gt(gt_path, labelled_pc.astype(int), instance_ids.astype(int))
    print(f"Ground truth сохранен в {gt_path}")
if __name__ == '__main__':
    # Command-line entry point: every option is a required string.
    # Help texts (Russian): scene directory path, ground-truth output
    # directory, label-map file path, scene name.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ('--scene_path', 'Путь к директории сцены'),
        ('--gt_dir', 'Директория для сохранения ground truth'),
        ('--label_map', 'Путь к файлу маппинга меток'),
        ('--scene_name', 'Имя сцены'),
    ):
        cli.add_argument(flag, required=True, type=str, help=description)
    options = cli.parse_args()

    # The label map is a tab-separated file whose first row is the header.
    label_table = pd.read_csv(options.label_map, sep='\t', header=0)

    os.makedirs(options.gt_dir, exist_ok=True)
    handle_single_scene(options.scene_path, options.gt_dir, label_table, options.scene_name)