# zoo3d / MaskClustering / cluster_masks.py
# Provenance: Hugging Face file view — commit 55e58d1 ("adding real MK") by bulatko, 3.04 kB
import numpy as np
import os
import cv2
from pathlib import Path
import trimesh as tm
from sklearn.neighbors import KDTree
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
from sklearn.cluster import DBSCAN
def load_scan(pcd_path):
    """Load a ScanNet point cloud from a raw float32 binary file.

    The file is expected to hold records of 6 float32 values per point
    (presumably x, y, z, r, g, b — confirm against the data pipeline);
    only the first three columns (coordinates) are returned.

    Args:
        pcd_path: path to the .bin point-cloud file.

    Returns:
        np.ndarray of shape (num_points, 3), dtype float32.
    """
    raw = np.fromfile(pcd_path, dtype=np.float32)
    records = raw.reshape(-1, 6)
    return records[:, :3]
def process_scene(data):
    """Denoise the predicted 3D instance masks of one ScanNet scene.

    For every predicted point mask, spatially clusters the masked points with
    DBSCAN (eps=0.3, min_samples=10) and keeps only the largest non-noise
    cluster, discarding spatial outliers.  The filtered masks are saved to a
    new .npz under the experiment's prediction directory.

    Args:
        data: tuple ``(scene_id, exp_name)`` — scene identifier and the name
            of the output experiment directory.

    Side effects: reads prediction/scan files from hard-coded cluster paths,
    creates the output directory, writes ``{exp_name}/{scene_id}.npz`` and
    prints a summary line.
    """
    scene_id, exp_name = data
    # Hard-coded environment-specific paths — adjust for your setup.
    pred_path = Path(f"data/prediction/scannet/click_sam/{scene_id}.npz")
    out_path = Path(f"data/prediction/scannet/{exp_name}/{scene_id}.npz")
    base_path = Path(f"/home/jovyan/users/lemeshko/scripts/gsam_result/yolo/{scene_id}")
    source_path = Path(f"/home/jovyan/users/kolodiazhnyi/data/scannet/posed_images/{scene_id}")
    scan_path = Path(f"/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannet200/points/{scene_id}.bin")
    info_path = base_path / "infos.npy"
    vertices = load_scan(scan_path)
    # NOTE(review): info_data is loaded but never used below — confirm it is still needed.
    info_data = np.load(info_path, allow_pickle=True).item()
    base_data = np.load(pred_path, allow_pickle=True)
    # Transpose to (num_masks, num_points) so each row is one instance mask.
    total_points_masks = base_data['pred_masks'].T
    # Hoisted out of the loop: the clustering parameters are constant and
    # fit_predict re-fits from scratch on every call.
    db = DBSCAN(eps=0.3, min_samples=10)
    for i, mask in enumerate(total_points_masks):
        mask = mask.astype(bool)
        points = vertices[mask]
        if len(points) == 0:
            continue
        labels = db.fit_predict(points)
        if (labels == -1).all():
            continue  # every masked point is DBSCAN noise; leave mask as-is
        # Keep only the largest non-noise cluster.  Select by the actual label
        # value: the original code used np.argmax's *index* as the label, which
        # was only correct because DBSCAN labels happen to be consecutive
        # integers starting at 0.
        uniq_labels, counts = np.unique(labels[labels != -1], return_counts=True)
        biggest_label = uniq_labels[np.argmax(counts)]
        res_mask = labels == biggest_label
        # Scatter the kept-points flags back to full point-cloud resolution.
        new_mask = np.zeros_like(mask)
        new_mask[mask] = res_mask
        total_points_masks[i] = new_mask
    new_data = {
        k: v for k, v in base_data.items()
    }
    new_data['pred_masks'] = total_points_masks.T
    out_path.parent.mkdir(parents=True, exist_ok=True)
    print("uniques", np.unique(new_data['pred_masks'].sum(1)), [[k, v.shape] for k, v in new_data.items()])
    np.savez(out_path, **new_data)
if __name__ == "__main__":
    # Output experiment directory name under data/prediction/scannet/.
    exp_name = "cluster_filtering_click_sam"
    # Full split (disabled): np.loadtxt(
    #     "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/splits/scannet.txt",
    #     dtype=str)
    scenes = ["scene0011_00"]
    # One (scene_id, exp_name) job per scene, processed by a thread pool.
    data = list(zip(scenes, [exp_name] * len(scenes)))
    total_points_masks = thread_map(process_scene, data, chunksize=20)