import os
import pickle
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm.auto import tqdm
class PredBBoxDistrPP:
    """Score statistics and box extraction for predicted instance masks.

    On construction the ground-truth (GT) info pickle and the prediction
    directory are scanned and the following attributes are filled in:

    * ``scene_ids`` - names of GT scenes that also have a ``<scene>.npz``
      prediction file in ``path``;
    * ``gt_sample_counts`` - number of GT instances per class label;
    * ``class_scores`` - prediction scores grouped by predicted class label;
    * ``sorted_names`` - class labels sorted by ascending GT instance count.

    The remaining methods plot the score distributions, derive per-class
    score thresholds, and turn predicted masks into boxes written in the
    ``bbox_3d`` / ``bbox_label_3d`` instance format.
    """

    SCANNET_IDS = [4, 3, 6, 5, 9, 7, 8, 10, 12, 11, 14, 13, 23, 17, 18, 24, 25, 27, 28, 47, 88, 35, 36, 42, 45, 58, 49, 54, 56, 59, 60, 63, 67, 68, 102, 71, 72, 74, 81, 83, 90, 96, 122, 416, 106, 111, 117, 126, 129, 132, 155, 166, 173, 188, 300, 199, 204, 214, 219, 253, 299, 265, 273, 352, 295, 296, 301, 305, 312, 342, 358, 364, 368, 387, 395, 396, 403, 405, 414, 443, 469, 515, 744, 1157]

    SCANNET_LABELS = ['table', 'door', 'ceiling lamp', 'cabinet', 'blinds', 'curtain', 'chair', 'storage cabinet', 'office chair', 'bookshelf', 'whiteboard', 'window', 'box',
                      'monitor', 'shelf', 'heater', 'kitchen cabinet', 'sofa', 'bed', 'trash can', 'book', 'plant', 'blanket', 'tv', 'computer tower', 'refrigerator', 'jacket',
                      'sink', 'bag', 'picture', 'pillow', 'towel', 'suitcase', 'backpack', 'crate', 'keyboard', 'rack', 'toilet', 'printer', 'poster', 'painting', 'microwave', 'shoes',
                      'socket', 'bottle', 'bucket', 'cushion', 'basket', 'shoe rack', 'telephone', 'file folder', 'laptop', 'plant pot', 'exhaust fan', 'cup', 'coat hanger', 'light switch',
                      'speaker', 'table lamp', 'kettle', 'smoke detector', 'container', 'power strip', 'slippers', 'paper bag', 'mouse', 'cutting board', 'toilet paper', 'paper towel',
                      'pot', 'clock', 'pan', 'tap', 'jar', 'soap dispenser', 'binder', 'bowl', 'tissue box', 'whiteboard eraser', 'toilet brush', 'spray bottle', 'headphones', 'stapler', 'marker']

    # Raw class id -> label, label -> raw class id, raw class id -> position
    # of that id in SCANNET_IDS (the contiguous label index written to pkl).
    ID2LABEL = dict(zip(SCANNET_IDS, SCANNET_LABELS))

    LABEL2ID = dict(zip(SCANNET_LABELS, SCANNET_IDS))

    INV_SCANNET_IDS = {idx: i for i, idx in enumerate(SCANNET_IDS)}

    @staticmethod
    def _normalize_scene_id(value):
        """Return the bare scene id for a scene id, file name or path.

        The directory part is dropped, then a trailing ``.bin`` (or any other
        final extension) is removed.
        """
        base = os.path.basename(value)
        if base.endswith('.bin'):
            return base[:-4]
        return os.path.splitext(base)[0]

    def __init__(self, path, bins_path, gt_pkl_path):
        """Scan GT infos and prediction files and collect per-class scores.

        Parameters:
            path: directory with per-scene ``<scene>.npz`` prediction files
            bins_path: directory with per-scene ``<scene>.bin`` point files
            gt_pkl_path: path to the GT info pickle
        """
        self.path = path
        self.bins_path = bins_path
        self.gt_pkl_path = gt_pkl_path

        self.get_scenes()
        self.class_scores = defaultdict(list)
        for scene_id in self.scene_ids:
            self.get_scene_inst(scene_id)
        self.sorted_names = sorted(self.SCANNET_LABELS, key=lambda x: self.gt_sample_counts[x])

    def _gt_cache_entry(self):
        """Return the cache record for ``self.gt_pkl_path``, loading it once.

        The record is reloaded if ``gt_pkl_path`` was reassigned. ``getattr``
        is used so the method also works on instances built without
        ``__init__``.
        """
        entry = getattr(self, '_gt_cache', None)
        if entry is None or entry['path'] != self.gt_pkl_path:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # gt_pkl_path at files from a trusted source.
            with open(self.gt_pkl_path, 'rb') as file:
                data = pickle.load(file)
            entry = {'path': self.gt_pkl_path, 'data': data, 'index': None}
            self._gt_cache = entry
        return entry

    def _load_gt_data(self):
        """Return the unpickled GT info object (cached after the first read)."""
        return self._gt_cache_entry()['data']

    def _list_predicted_scenes(self):
        """Return the set of scene names that have a ``.npz`` prediction file.

        Only ``*.npz`` entries are considered, so unrelated files in the
        prediction directory are never mistaken for scenes.
        """
        return {name[:-4] for name in os.listdir(self.path) if name.endswith('.npz')}

    def load_pkl_scene_by_id(self, scene_id):
        """Return the GT scene record for ``scene_id``, or ``None`` if absent.

        ``scene_id`` may be a bare id or a path / file name with an
        extension; it is normalized with ``_normalize_scene_id`` and compared
        against the normalized ``lidar_points.lidar_path`` of every record.
        When several records share an id the first one is returned.

        The record is the object held in the in-memory GT cache, not a fresh
        copy: treat it as read-only (copy it before modifying).
        """
        entry = self._gt_cache_entry()
        if entry['index'] is None:
            index = {}
            for scene in entry['data'].get('data_list', []):
                lidar_path = scene.get('lidar_points', {}).get('lidar_path')
                if not lidar_path:
                    continue
                # setdefault keeps the first record for a duplicated id.
                index.setdefault(self._normalize_scene_id(lidar_path), scene)
            entry['index'] = index
        return entry['index'].get(self._normalize_scene_id(scene_id))

    def get_scenes(self):
        """Fill ``scene_ids`` and ``gt_sample_counts`` from the GT pickle.

        Only scenes that also have a prediction file are kept. GT labels
        outside ``[0, len(SCANNET_LABELS))`` are not counted; without this
        check a negative label would be counted as a class from the end of
        the label list.
        """
        self.scene_ids = []
        self.gt_sample_counts = defaultdict(int)
        data = self._load_gt_data()
        picked_scenes = self._list_predicted_scenes()
        num_labels = len(self.SCANNET_LABELS)
        for scene in data['data_list']:
            scene_name = self._normalize_scene_id(scene['lidar_points']['lidar_path'])
            if scene_name not in picked_scenes:
                continue
            self.scene_ids.append(scene_name)
            for instance in scene['instances']:
                inst_id = instance['bbox_label_3d']
                if not 0 <= inst_id < num_labels:
                    continue
                self.gt_sample_counts[self.SCANNET_LABELS[inst_id]] += 1

    def get_scene_inst(self, scene_id):
        """Append the prediction scores of one scene to ``class_scores``."""
        cls_path = f'{self.path}/{scene_id}.npz'
        # The context manager closes the archive's file handle.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            pred_classes = cls_data['pred_classes']
            pred_scores = cls_data['pred_score']
        for class_id, class_score in zip(pred_classes, pred_scores):
            self.class_scores[self.ID2LABEL[class_id]].append(class_score)

    def _select_scores(self, class_name):
        """Return the scores selected by ``class_name``.

        ``'all'`` gives a new list with every score, a list of labels gives a
        new list with the scores of the known labels (a warning is printed
        for each unknown one), and a single label gives the stored list for
        that label. For an unknown single label a message is printed and
        ``None`` is returned.
        """
        if class_name == 'all':
            return list(chain.from_iterable(self.class_scores.values()))
        if isinstance(class_name, list):
            selected_scores = []
            for cls in class_name:
                if cls in self.class_scores:
                    selected_scores.extend(self.class_scores[cls])
                else:
                    print(f"Warning: Class '{cls}' not found in class_scores")
            return selected_scores
        if class_name not in self.class_scores:
            print(f"Class '{class_name}' not found in class_scores")
            print(f"Available classes: {list(self.class_scores.keys())[:10]}...")
            return None
        return self.class_scores[class_name]

    def plot_class_distr(self, class_name='all', percentile=32.45):
        """Plot the score distribution of one class, several classes or all.

        Parameters:
            class_name: str or list - a class label, 'all' for every class,
                        or a list of class labels
            percentile: percentile drawn as the "Size bound" line and
                        reported as ``Q`` in the statistics box

        Returns the list of plotted scores, or ``None`` when the class is
        unknown or there is nothing to plot.
        """
        scores = self._select_scores(class_name)
        if scores is None:
            return

        if class_name == 'all':
            display_name = 'All Classes'
            class_info = f"Total classes: {len(self.class_scores)}"
        elif isinstance(class_name, list):
            display_name = f'Classes: {", ".join(class_name[:3])}{"..." if len(class_name) > 3 else ""}'
            class_info = f"Selected classes: {len(class_name)}"
        else:
            display_name = class_name
            class_info = f"Class: {class_name}"

        if not scores:
            print(f"No scores available for: {display_name}")
            return

        fig, ax = plt.subplots(figsize=(12, 8))

        sns.histplot(scores, bins=30, kde=True, ax=ax, color='skyblue',
                     stat='density', alpha=0.7)
        ax.set_title(f'Distribution of scores for {display_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)

        median_score = np.median(scores)
        bound_score = np.percentile(scores, percentile)
        ax.axvline(median_score, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_score:.3f}')
        ax.axvline(bound_score, color='red', linestyle='-', linewidth=2,
                   label=f'Size bound: {bound_score:.3f}')

        ax.legend()

        stats_text = "\n".join([
            f"Statistics for {display_name}:",
            f"{class_info}",
            f"Total instances: {len(scores):,}",
            f"Mean: {np.mean(scores):.3f}",
            f"Median: {median_score:.3f}",
            f"Std: {np.std(scores):.3f}",
            f"Min: {np.min(scores):.3f}",
            f"Max: {np.max(scores):.3f}",
            f"Q1: {np.percentile(scores, 25):.3f}",
            f"Q : {bound_score:.3f}",
            f"Q3: {np.percentile(scores, 75):.3f}",
        ])

        props = dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8)
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontfamily='monospace',
                verticalalignment='top', bbox=props, fontsize=10)

        plt.tight_layout()
        plt.show()

        print(stats_text)

        return scores

    def plot_multiple_classes(self, class_names: list):
        """Compare the score distributions of several classes on one plot.

        Unknown classes are skipped with a warning; classes without scores
        are skipped silently. Colors repeat after five classes.
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold', 'lightpink']

        for i, cls in enumerate(class_names):
            if cls not in self.class_scores:
                print(f"Warning: Class '{cls}' not found, skipping")
                continue

            scores = self.class_scores[cls]
            if scores:
                sns.kdeplot(scores, ax=ax, label=cls, color=colors[i % len(colors)],
                            linewidth=2, alpha=0.8)

        ax.set_title('Score Distribution Comparison', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend()

        plt.tight_layout()
        plt.show()

    def get_class_lowerbound(self, class_name='all', percentile=32.45):
        """Return the ``percentile``-th percentile of the selected scores.

        ``class_name`` is interpreted as in ``plot_class_distr``. Returns
        ``None`` (after printing a message) for an unknown single class.
        """
        scores = self._select_scores(class_name)
        if scores is None:
            return None
        return np.percentile(scores, percentile)

    def get_bboxes_by_masks(self, masks, points):
        """Build one box per mask from the points the mask selects.

        Each box is ``(cx, cy, cz, dx, dy, dz)``: center and size of the
        range between the 1% and 99% per-axis quantiles of the selected
        points, which trims outliers compared with a plain min/max.

        Parameters:
            masks: iterable of masks usable to index ``points`` along dim 0
            points: tensor whose first three columns are x, y, z

        Returns a tensor with one row per mask.

        Raises:
            AssertionError: if ``masks`` is empty.
        """
        boxes = []
        for mask in masks:
            object_points = points[mask][:, :3]
            xyz_min = object_points.quantile(0.01, dim=0)
            xyz_max = object_points.quantile(0.99, dim=0)
            center = (xyz_max + xyz_min) / 2
            size = xyz_max - xyz_min
            boxes.append(torch.cat((center, size)))
        if not boxes:
            # Raised explicitly (not via `assert`) so the check survives -O.
            raise AssertionError("Why 0 masks in scene?")
        return torch.stack(boxes)

    def _load_aligned_points(self, scene_id, gt_scene):
        """Load ``<bins_path>/<scene_id>.bin`` as an (N, 6) float32 tensor.

        When ``gt_scene`` carries an ``axis_align_matrix``, its upper-left
        3x3 block and the first three entries of its last column are applied
        to the xyz columns as rotation and translation.
        """
        points_path = f'{self.bins_path}/{scene_id}.bin'
        points = torch.from_numpy(np.fromfile(points_path, dtype=np.float32).reshape((-1, 6)))
        if gt_scene is not None and 'axis_align_matrix' in gt_scene:
            a = torch.as_tensor(np.array(gt_scene['axis_align_matrix'], dtype=np.float32))
            rotation = a[:3, :3]
            translation = a[:3, 3]
            points[:, :3] = points[:, :3] @ rotation.T + translation
        return points

    def _load_predictions(self, scene_id):
        """Return ``(pred_masks, pred_classes, pred_scores)`` for a scene.

        ``pred_masks`` is the stored mask array transposed, so that iterating
        it yields one mask per predicted instance (presumably the file stores
        it as points x instances - confirm against the producer).
        """
        cls_path = f'{self.path}/{scene_id}.npz'
        # The context manager closes the archive's file handle.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            pred_masks = torch.from_numpy(cls_data['pred_masks']).T
            pred_classes = cls_data['pred_classes']
            pred_scores = cls_data['pred_score']
        return pred_masks, pred_classes, pred_scores

    def _make_instance(self, box, pred_class, class_agnostic):
        """Return the instance dict for one box and its raw class id."""
        write_class = 0 if class_agnostic else self.INV_SCANNET_IDS[pred_class]
        return {'bbox_3d': box.numpy().tolist(), 'bbox_label_3d': write_class}

    def get_scene_instances(self, scene_name, score_bounds, class_agnostic):
        """Return predicted instances whose score exceeds the class bound.

        Parameters:
            scene_name: scene id used to locate the ``.bin`` / ``.npz`` files
            score_bounds: dict raw class id -> minimal score (exclusive);
                          classes missing from the dict use a bound of 0
            class_agnostic: write label 0 for every instance when true,
                            otherwise the index of the class in SCANNET_IDS
        """
        gt_scene = self.load_pkl_scene_by_id(scene_name)
        points = self._load_aligned_points(scene_name, gt_scene)
        pred_masks, pred_classes, pred_scores = self._load_predictions(scene_name)
        boxes = self.get_bboxes_by_masks(pred_masks, points)
        instances = []
        for box, pred_class, pred_score in zip(boxes, pred_classes, pred_scores):
            if pred_score > score_bounds.get(pred_class, 0):
                instances.append(self._make_instance(box, pred_class, class_agnostic))
        return instances

    def filter_instances_topk_by_gt(self, scene_name, class_agnostic=True):
        """Keep the top-K predicted instances, K = number of GT instances.

        Steps:
            1) look up the GT record of the scene and count its instances;
            2) sort the predictions by descending score;
            3) keep the first K and convert their masks to 3D boxes.

        Returns a list of instances in the mmdet3d format
        (``bbox_3d``, ``bbox_label_3d``); the list is empty when the scene
        has no GT instances or no predictions.
        """
        scene_id = self._normalize_scene_id(scene_name)
        gt_scene = self.load_pkl_scene_by_id(scene_name)
        gt_count = len(gt_scene.get('instances', [])) if gt_scene else 0
        if gt_count <= 0:
            return []

        points = self._load_aligned_points(scene_id, gt_scene)
        pred_masks, pred_classes, pred_scores = self._load_predictions(scene_id)
        if len(pred_scores) == 0:
            return []

        topk = int(min(gt_count, len(pred_scores)))
        np_topk_indices = np.argsort(-pred_scores)[:topk]
        torch_topk_indices = torch.as_tensor(np_topk_indices, dtype=torch.long)

        # Boxes are computed only for the selected masks.
        boxes = self.get_bboxes_by_masks(pred_masks[torch_topk_indices], points)
        selected_classes = pred_classes[np_topk_indices]
        return [self._make_instance(box, pred_class, class_agnostic)
                for box, pred_class in zip(boxes, selected_classes)]

    def make_pkl(self, percentiles, pkl_path, class_agnostic=True):
        """Write a copy of the GT pickle with score-filtered predicted boxes.

        Parameters:
            percentiles: iterable of ``(classes, percentile)`` pairs, where
                         ``classes`` is 'all', a class label or a list of
                         labels; the resulting bound is assigned to every
                         listed class, later pairs overriding earlier ones
            pkl_path: output pickle path
            class_agnostic: forwarded to ``get_scene_instances``
        """
        score_bounds = {}
        for classes, percentile in percentiles:
            score_bound = self.get_class_lowerbound(classes, percentile)
            if classes == 'all':
                classes = self.sorted_names
            if isinstance(classes, list):
                for class_ in classes:
                    score_bounds[self.LABEL2ID[class_]] = score_bound
            else:
                score_bounds[self.LABEL2ID[classes]] = score_bound
        print(score_bounds)

        data = self._load_gt_data()
        new_data = {'metainfo': data['metainfo']}
        data_list = []
        picked_scenes = self._list_predicted_scenes()
        for scene in tqdm(data['data_list']):
            scene_name = self._normalize_scene_id(scene['lidar_points']['lidar_path'])
            if scene_name not in picked_scenes:
                continue
            # Deep copy so the cached GT records are left untouched.
            tmp_scene = deepcopy(scene)
            tmp_scene['instances'] = self.get_scene_instances(scene_name, score_bounds, class_agnostic)
            data_list.append(tmp_scene)
        new_data['data_list'] = data_list
        with open(pkl_path, 'wb') as file:
            pickle.dump(new_data, file)

    @property
    def scores(self):
        """Mapping class label -> list of prediction scores."""
        return self.class_scores
if __name__ == "__main__":
    # Inputs: prediction archives, point-cloud bins and the GT info pickle;
    # output: a pickle with the GT scenes whose instances are replaced by
    # the top-K predicted boxes.
    pred_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/data/prediction/scannetpp_dust3r_posed"
    )
    bins_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/points_dust3r_posed"
    )
    out_pkl_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/scannetpp84_dust3r_posed_train10.pkl"
    )
    gt_pkl_path = (
        "/home/jovyan/users/lemeshko/TMP/my_pkls/scannetpp_infos_84class_train.pkl"
    )

    distr = PredBBoxDistrPP(pred_path, bins_path, gt_pkl_path)

    with open(gt_pkl_path, 'rb') as file:
        gt_data = pickle.load(file)

    new_data = {"metainfo": gt_data["metainfo"]}
    data_list = []

    # Scene names that have a prediction file (the 4-character suffix of
    # every directory entry is dropped).
    available = {entry[:-4] for entry in os.listdir(distr.path)}
    for scene in tqdm(gt_data['data_list']):
        scene_name = distr._normalize_scene_id(scene['lidar_points']['lidar_path'])
        if scene_name in available:
            scene_copy = deepcopy(scene)
            scene_copy['instances'] = distr.filter_instances_topk_by_gt(scene_name, class_agnostic=False)
            data_list.append(scene_copy)

    new_data['data_list'] = data_list
    with open(out_pkl_path, 'wb') as f:
        pickle.dump(new_data, f)