|
|
import numpy as np |
|
|
import os |
|
|
import pickle |
|
|
from tqdm.auto import tqdm |
|
|
from collections import defaultdict |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from copy import deepcopy |
|
|
import torch |
|
|
|
|
|
class PredBBoxDistrPP:
    """Analyze predicted-instance score distributions on ScanNet++ (84-class
    subset) and build filtered mmdet3d-style annotation PKLs from per-scene
    instance predictions (.npz files with masks, class ids and scores)."""

    # Raw ScanNet++ category ids used by the 84-class benchmark subset.
    SCANNET_IDS = [4, 3, 6, 5, 9, 7, 8, 10, 12, 11, 14, 13, 23, 17, 18, 24, 25, 27, 28, 47, 88, 35, 36, 42, 45, 58, 49, 54, 56, 59, 60, 63, 67, 68, 102, 71, 72, 74, 81, 83, 90, 96, 122, 416, 106, 111, 117, 126, 129, 132, 155, 166, 173, 188, 300, 199, 204, 214, 219, 253, 299, 265, 273, 352, 295, 296, 301, 305, 312, 342, 358, 364, 368, 387, 395, 396, 403, 405, 414, 443, 469, 515, 744, 1157]

    # Human-readable class labels, index-aligned with SCANNET_IDS.
    SCANNET_LABELS = ['table', 'door', 'ceiling lamp', 'cabinet', 'blinds', 'curtain', 'chair', 'storage cabinet', 'office chair', 'bookshelf', 'whiteboard', 'window', 'box',
                      'monitor', 'shelf', 'heater', 'kitchen cabinet', 'sofa', 'bed', 'trash can', 'book', 'plant', 'blanket', 'tv', 'computer tower', 'refrigerator', 'jacket',
                      'sink', 'bag', 'picture', 'pillow', 'towel', 'suitcase', 'backpack', 'crate', 'keyboard', 'rack', 'toilet', 'printer', 'poster', 'painting', 'microwave', 'shoes',
                      'socket', 'bottle', 'bucket', 'cushion', 'basket', 'shoe rack', 'telephone', 'file folder', 'laptop', 'plant pot', 'exhaust fan', 'cup', 'coat hanger', 'light switch',
                      'speaker', 'table lamp', 'kettle', 'smoke detector', 'container', 'power strip', 'slippers', 'paper bag', 'mouse', 'cutting board', 'toilet paper', 'paper towel',
                      'pot', 'clock', 'pan', 'tap', 'jar', 'soap dispenser', 'binder', 'bowl', 'tissue box', 'whiteboard eraser', 'toilet brush', 'spray bottle', 'headphones', 'stapler', 'marker']

    # Raw id <-> label lookup tables.
    ID2LABEL = dict(zip(SCANNET_IDS, SCANNET_LABELS))

    LABEL2ID = dict(zip(SCANNET_LABELS, SCANNET_IDS))

    # Raw ScanNet++ id -> contiguous training index in [0, len(SCANNET_IDS)).
    INV_SCANNET_IDS = {idx: i for i, idx in enumerate(SCANNET_IDS)}
|
|
|
|
|
@staticmethod |
|
|
def _normalize_scene_id(value): |
|
|
base = os.path.basename(value) |
|
|
if base.endswith('.bin'): |
|
|
base = base[:-4] |
|
|
else: |
|
|
base = os.path.splitext(base)[0] |
|
|
return base |
|
|
|
|
|
def __init__(self, path, bins_path, gt_pkl_path): |
|
|
self.path = path |
|
|
self.bins_path = bins_path |
|
|
self.gt_pkl_path = gt_pkl_path |
|
|
|
|
|
self.get_scenes() |
|
|
self.class_scores = defaultdict(list) |
|
|
for scene_id in self.scene_ids: |
|
|
self.get_scene_inst(scene_id) |
|
|
self.sorted_names = sorted(self.SCANNET_LABELS, key=lambda x: self.gt_sample_counts[x]) |
|
|
|
|
|
def load_pkl_scene_by_id(self, scene_id): |
|
|
""" |
|
|
Вернуть описание сцены из PKL по scene_id (без расширения). |
|
|
Поддерживает как id вида "sceneXXXX_YY", так и пути/имена с .bin. |
|
|
""" |
|
|
target_id = self._normalize_scene_id(scene_id) |
|
|
with open(self.gt_pkl_path, 'rb') as file: |
|
|
data = pickle.load(file) |
|
|
for scene in data.get('data_list', []): |
|
|
lidar_path = scene.get('lidar_points', {}).get('lidar_path') |
|
|
if not lidar_path: |
|
|
continue |
|
|
candidate_id = self._normalize_scene_id(lidar_path) |
|
|
if candidate_id == target_id: |
|
|
return scene |
|
|
return None |
|
|
|
|
|
def get_scenes(self): |
|
|
self.scene_ids = [] |
|
|
self.gt_sample_counts = defaultdict(int) |
|
|
with open('/home/jovyan/users/lemeshko/TMP/my_pkls/scannetpp_infos_84class_train.pkl', 'rb') as file: |
|
|
data = pickle.load(file) |
|
|
picked_scenes = set(map(lambda x: x[:-4], os.listdir(self.path))) |
|
|
for scene in data['data_list']: |
|
|
scene_name = scene['lidar_points']['lidar_path'][:-4] |
|
|
if scene_name not in picked_scenes: |
|
|
continue |
|
|
self.scene_ids.append(scene_name) |
|
|
for instance in scene['instances']: |
|
|
inst_id = instance['bbox_label_3d'] |
|
|
self.gt_sample_counts[self.SCANNET_LABELS[inst_id]] += 1 |
|
|
|
|
|
def get_scene_inst(self, scene_id): |
|
|
cls_path = f'{self.path}/{scene_id}.npz' |
|
|
cls_data = np.load(cls_path, allow_pickle=True) |
|
|
for class_id, class_score in zip(cls_data['pred_classes'], cls_data['pred_score']): |
|
|
self.class_scores[self.ID2LABEL[class_id]].append(class_score) |
|
|
|
|
|
    def plot_class_distr(self, class_name='all'):
        """
        Plot the score distribution for one class, a list of classes,
        or all classes combined.

        Parameters:
            class_name: str or list - class name, 'all' for every class,
                        or a list of class names

        Returns the list of scores that was plotted, or None when nothing
        could be plotted.
        """
        if class_name == 'all':
            # Pool the scores of every class together.
            all_scores = []
            for scores in self.class_scores.values():
                all_scores.extend(scores)
            scores = all_scores
            display_name = 'All Classes'
        elif isinstance(class_name, list):
            # Pool the scores of the requested subset of classes.
            selected_scores = []
            for cls in class_name:
                if cls in self.class_scores:
                    selected_scores.extend(self.class_scores[cls])
                else:
                    print(f"Warning: Class '{cls}' not found in class_scores")
            scores = selected_scores
            display_name = f'Classes: {", ".join(class_name[:3])}{"..." if len(class_name) > 3 else ""}'
        else:
            # Single class name.
            if class_name not in self.class_scores:
                print(f"Class '{class_name}' not found in class_scores")
                print(f"Available classes: {list(self.class_scores.keys())[:10]}...")
                return
            scores = self.class_scores[class_name]
            display_name = class_name

        if not scores:
            print(f"No scores available for: {display_name}")
            return

        fig, ax = plt.subplots(figsize=(12, 8))

        # Histogram with a KDE overlay, normalized to a density.
        sns.histplot(scores, bins=30, kde=True, ax=ax, color='skyblue',
                     stat='density', alpha=0.7)
        ax.set_title(f'Distribution of scores for {display_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)

        # Vertical markers: median and the percentile used as a score bound.
        # NOTE(review): 32.45 looks like an empirically chosen percentile
        # (same value as get_class_lowerbound's default) -- confirm origin.
        median_score = np.median(scores)
        ax.axvline(median_score, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_score:.3f}')
        ax.axvline(np.percentile(scores, 32.45), color='red', linestyle='-', linewidth=2,
                   label=f'Size bound: {np.percentile(scores, 32.45):.3f}')

        ax.legend()

        if class_name == 'all':
            class_info = f"Total classes: {len(self.class_scores)}"
        elif isinstance(class_name, list):
            class_info = f"Selected classes: {len(class_name)}"
        else:
            class_info = f"Class: {class_name}"

        stats_text = f"""Statistics for {display_name}:
{class_info}
Total instances: {len(scores):,}
Mean: {np.mean(scores):.3f}
Median: {np.median(scores):.3f}
Std: {np.std(scores):.3f}
Min: {np.min(scores):.3f}
Max: {np.max(scores):.3f}
Q1: {np.percentile(scores, 25):.3f}
Q : {np.percentile(scores, 32.45):.3f}
Q3: {np.percentile(scores, 75):.3f}"""

        # Stats box in the top-left corner of the axes.
        props = dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8)
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontfamily='monospace',
                verticalalignment='top', bbox=props, fontsize=10)

        plt.tight_layout()
        plt.show()

        print(stats_text)

        return scores
|
|
|
|
|
|
|
|
def plot_multiple_classes(self, class_names: list): |
|
|
""" |
|
|
Сравнить распределения нескольких классов на одном графике |
|
|
""" |
|
|
fig, ax = plt.subplots(figsize=(12, 8)) |
|
|
|
|
|
colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold', 'lightpink'] |
|
|
|
|
|
for i, cls in enumerate(class_names): |
|
|
if cls not in self.class_scores: |
|
|
print(f"Warning: Class '{cls}' not found, skipping") |
|
|
continue |
|
|
|
|
|
scores = self.class_scores[cls] |
|
|
if scores: |
|
|
sns.kdeplot(scores, ax=ax, label=cls, color=colors[i % len(colors)], |
|
|
linewidth=2, alpha=0.8) |
|
|
|
|
|
ax.set_title('Score Distribution Comparison', fontsize=14, fontweight='bold') |
|
|
ax.set_xlabel('Score', fontsize=12) |
|
|
ax.set_ylabel('Density', fontsize=12) |
|
|
ax.grid(True, alpha=0.3) |
|
|
ax.legend() |
|
|
|
|
|
plt.tight_layout() |
|
|
plt.show() |
|
|
|
|
|
def get_class_lowerbound(self, class_name='all', percentile=32.45): |
|
|
if class_name == 'all': |
|
|
|
|
|
all_scores = [] |
|
|
for scores in self.class_scores.values(): |
|
|
all_scores.extend(scores) |
|
|
scores = all_scores |
|
|
elif isinstance(class_name, list): |
|
|
selected_scores = [] |
|
|
for cls in class_name: |
|
|
if cls in self.class_scores: |
|
|
selected_scores.extend(self.class_scores[cls]) |
|
|
else: |
|
|
print(f"Warning: Class '{cls}' not found in class_scores") |
|
|
scores = selected_scores |
|
|
else: |
|
|
|
|
|
if class_name not in self.class_scores: |
|
|
print(f"Class '{class_name}' not found in class_scores") |
|
|
print(f"Available classes: {list(self.class_scores.keys())[:10]}...") |
|
|
return |
|
|
scores = self.class_scores[class_name] |
|
|
|
|
|
return np.percentile(scores, percentile) |
|
|
|
|
|
def get_bboxes_by_masks(self, masks, points): |
|
|
boxes = [] |
|
|
for mask in masks: |
|
|
object_points = points[mask][:, :3] |
|
|
|
|
|
|
|
|
xyz_min = object_points.quantile(0.01, dim=0) |
|
|
xyz_max = object_points.quantile(0.99, dim=0) |
|
|
center = (xyz_max + xyz_min) / 2 |
|
|
size = xyz_max - xyz_min |
|
|
box = torch.cat((center, size)) |
|
|
boxes.append(box) |
|
|
assert len(boxes) != 0, "Why 0 masks in scene?" |
|
|
boxes = torch.stack(boxes) |
|
|
return boxes |
|
|
|
|
|
def get_scene_instances(self, scene_name, score_bounds, class_agnostic): |
|
|
instances = [] |
|
|
points_path = f'{self.bins_path}/{scene_name}.bin' |
|
|
points = torch.from_numpy(np.fromfile(points_path, dtype=np.float32).reshape((-1, 6))) |
|
|
|
|
|
gt_scene = self.load_pkl_scene_by_id(scene_name) |
|
|
if gt_scene is not None and 'axis_align_matrix' in gt_scene: |
|
|
a = torch.as_tensor(np.array(gt_scene['axis_align_matrix'], dtype=np.float32)) |
|
|
R = a[:3, :3] |
|
|
t = a[:3, 3] |
|
|
xyz = points[:, :3] |
|
|
points[:, :3] = xyz @ R.T + t |
|
|
cls_path = f'{self.path}/{scene_name}.npz' |
|
|
cls_data = np.load(cls_path, allow_pickle=True) |
|
|
pred_masks = torch.from_numpy(cls_data['pred_masks']).T |
|
|
pred_classes = cls_data['pred_classes'] |
|
|
pred_scores = cls_data['pred_score'] |
|
|
boxes = self.get_bboxes_by_masks(pred_masks, points) |
|
|
for box, pred_class, pred_score in zip(boxes, pred_classes, pred_scores): |
|
|
if pred_score > score_bounds.get(pred_class, 0): |
|
|
write_class = 0 if class_agnostic else self.INV_SCANNET_IDS[pred_class] |
|
|
instances.append({'bbox_3d': box.numpy().tolist(), 'bbox_label_3d': write_class}) |
|
|
return instances |
|
|
|
|
|
def filter_instances_topk_by_gt(self, scene_name, class_agnostic=True): |
|
|
""" |
|
|
Фильтрует предсказанные инстансы по top-K, где K = количество GT-инстансов. |
|
|
Шаги: |
|
|
1) Берем все маски, переводим в 3D bbox-ы |
|
|
2) Сортируем по убыванию предикт-скор |
|
|
3) Оставляем top-K, где K равно числу GT-инстансов в PKL |
|
|
Возвращает список инстансов в формате mmdet3d (bbox_3d, bbox_label_3d). |
|
|
""" |
|
|
scene_id = self._normalize_scene_id(scene_name) |
|
|
gt_scene = self.load_pkl_scene_by_id(scene_name) |
|
|
gt_count = len(gt_scene.get('instances', [])) if gt_scene else 0 |
|
|
if gt_count <= 0: |
|
|
return [] |
|
|
|
|
|
points_path = f'{self.bins_path}/{scene_id}.bin' |
|
|
points = torch.from_numpy(np.fromfile(points_path, dtype=np.float32).reshape((-1, 6))) |
|
|
|
|
|
if gt_scene is not None and 'axis_align_matrix' in gt_scene: |
|
|
a = torch.as_tensor(np.array(gt_scene['axis_align_matrix'], dtype=np.float32)) |
|
|
print(a) |
|
|
R = a[:3, :3] |
|
|
t = a[:3, 3] |
|
|
xyz = points[:, :3] |
|
|
points[:, :3] = xyz @ R.T + t |
|
|
cls_path = f'{self.path}/{scene_id}.npz' |
|
|
cls_data = np.load(cls_path, allow_pickle=True) |
|
|
pred_masks = torch.from_numpy(cls_data['pred_masks']).T |
|
|
pred_classes = cls_data['pred_classes'] |
|
|
pred_scores = cls_data['pred_score'] |
|
|
|
|
|
if len(pred_scores) == 0: |
|
|
return [] |
|
|
|
|
|
topk = int(min(gt_count, len(pred_scores))) |
|
|
np_topk_indices = np.argsort(-pred_scores)[:topk] |
|
|
|
|
|
|
|
|
torch_topk_indices = torch.as_tensor(np_topk_indices, dtype=torch.long) |
|
|
selected_masks = pred_masks[torch_topk_indices] |
|
|
boxes = self.get_bboxes_by_masks(selected_masks, points) |
|
|
selected_classes = pred_classes[np_topk_indices] |
|
|
|
|
|
instances = [] |
|
|
for box, pred_class in zip(boxes, selected_classes): |
|
|
write_class = 0 if class_agnostic else self.INV_SCANNET_IDS[pred_class] |
|
|
instances.append({'bbox_3d': box.numpy().tolist(), 'bbox_label_3d': write_class}) |
|
|
return instances |
|
|
|
|
|
|
|
|
def make_pkl(self, percentiles, pkl_path, class_agnostic=True): |
|
|
|
|
|
score_bounds = {} |
|
|
for classes, percentile in percentiles: |
|
|
score_bound = self.get_class_lowerbound(classes, percentile) |
|
|
if classes == 'all': |
|
|
classes = self.sorted_names |
|
|
if isinstance(classes, list): |
|
|
for class_ in classes: |
|
|
score_bounds[self.LABEL2ID[class_]] = score_bound |
|
|
else: |
|
|
score_bounds[self.LABEL2ID[classes]] = score_bound |
|
|
print(score_bounds) |
|
|
new_data = {} |
|
|
with open(self.gt_pkl_path, 'rb') as file: |
|
|
data = pickle.load(file) |
|
|
new_data['metainfo'] = data['metainfo'] |
|
|
data_list = [] |
|
|
picked_scenes = set(map(lambda x: x[:-4], os.listdir(self.path))) |
|
|
for scene in tqdm(data['data_list']): |
|
|
scene_name = scene['lidar_points']['lidar_path'][:-4] |
|
|
if scene_name not in picked_scenes: |
|
|
continue |
|
|
tmp_scene = deepcopy(scene) |
|
|
instances = self.get_scene_instances(scene_name, score_bounds, class_agnostic) |
|
|
|
|
|
tmp_scene['instances'] = instances |
|
|
data_list.append(tmp_scene) |
|
|
new_data['data_list'] = data_list |
|
|
with open(pkl_path, 'wb') as file: |
|
|
pickle.dump(new_data, file) |
|
|
|
|
|
@property |
|
|
def scores(self): |
|
|
return self.class_scores |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Directory with per-scene .npz prediction files (masks, classes, scores).
    pred_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/data/prediction/scannetpp_dust3r_posed"
    # Directory with per-scene .bin point clouds.
    bins_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/points_dust3r_posed"
    # Output PKL that will hold the top-K-filtered predicted instances.
    out_pkl_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/scannetpp84_dust3r_posed_train10.pkl"
    # Ground-truth annotation PKL (mmdet3d format).
    gt_pkl_path = \
        "/home/jovyan/users/lemeshko/TMP/my_pkls/scannetpp_infos_84class_train.pkl"

    distr = PredBBoxDistrPP(pred_path, bins_path, gt_pkl_path)

    with open(gt_pkl_path, 'rb') as file:
        gt_data = pickle.load(file)

    new_data = {"metainfo": gt_data["metainfo"]}
    data_list = []

    # Only scenes that have a prediction file are carried over.
    picked_scenes = set(map(lambda x: x[:-4], os.listdir(distr.path)))
    for scene in tqdm(gt_data['data_list']):
        scene_name = distr._normalize_scene_id(scene['lidar_points']['lidar_path'])
        if scene_name not in picked_scenes:
            continue
        tmp_scene = deepcopy(scene)
        # Keep the top-K predictions per scene, K = number of GT instances.
        instances = distr.filter_instances_topk_by_gt(scene_name, class_agnostic=False)
        tmp_scene['instances'] = instances
        data_list.append(tmp_scene)

    new_data['data_list'] = data_list
    with open(out_pkl_path, 'wb') as f:
        pickle.dump(new_data, f)
|
|
|