|
|
import numpy as np |
|
|
import os |
|
|
import pickle |
|
|
from tqdm.auto import tqdm |
|
|
from collections import defaultdict |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import torch |
|
|
|
|
|
class PredBBoxDistrPP:
    """Collect per-class prediction scores and export confident boxes to a pickle.

    For every scene that has both a ``<scene>.npz`` prediction archive (keys
    ``pred_masks``, ``pred_classes``, ``pred_score``) in ``path`` and a
    ``<scene>.bin`` point file in ``bins_path``, the constructor gathers the
    predicted scores grouped by class label. The instance can then plot the
    score distributions, compute percentile-based lower bounds, and write the
    instances above a confidence threshold to an annotation pickle.
    """

    # Raw class ids as they appear in ``pred_classes``; the position in this
    # list is the contiguous label index written to ``bbox_label_3d``.
    SCANNET_IDS = [4, 3, 6, 5, 9, 7, 8, 10, 12, 11, 14, 13, 23, 17, 18, 24, 25, 27, 28, 47, 88, 35, 36, 42, 45, 58, 49, 54, 56, 59, 60, 63, 67, 68, 102, 71, 72, 74, 81, 83, 90, 96, 122, 416, 106, 111, 117, 126, 129, 132, 155, 166, 173, 188, 300, 199, 204, 214, 219, 253, 299, 265, 273, 352, 295, 296, 301, 305, 312, 342, 358, 364, 368, 387, 395, 396, 403, 405, 414, 443, 469, 515, 744, 1157]

    # Human-readable names, aligned position-by-position with SCANNET_IDS.
    SCANNET_LABELS = ['table', 'door', 'ceiling lamp', 'cabinet', 'blinds', 'curtain', 'chair', 'storage cabinet', 'office chair', 'bookshelf', 'whiteboard', 'window', 'box',
                      'monitor', 'shelf', 'heater', 'kitchen cabinet', 'sofa', 'bed', 'trash can', 'book', 'plant', 'blanket', 'tv', 'computer tower', 'refrigerator', 'jacket',
                      'sink', 'bag', 'picture', 'pillow', 'towel', 'suitcase', 'backpack', 'crate', 'keyboard', 'rack', 'toilet', 'printer', 'poster', 'painting', 'microwave', 'shoes',
                      'socket', 'bottle', 'bucket', 'cushion', 'basket', 'shoe rack', 'telephone', 'file folder', 'laptop', 'plant pot', 'exhaust fan', 'cup', 'coat hanger', 'light switch',
                      'speaker', 'table lamp', 'kettle', 'smoke detector', 'container', 'power strip', 'slippers', 'paper bag', 'mouse', 'cutting board', 'toilet paper', 'paper towel',
                      'pot', 'clock', 'pan', 'tap', 'jar', 'soap dispenser', 'binder', 'bowl', 'tissue box', 'whiteboard eraser', 'toilet brush', 'spray bottle', 'headphones', 'stapler', 'marker']

    # raw id -> label name
    ID2LABEL = dict(zip(SCANNET_IDS, SCANNET_LABELS))

    # label name -> raw id
    LABEL2ID = dict(zip(SCANNET_LABELS, SCANNET_IDS))

    # raw id -> contiguous index (position in SCANNET_IDS)
    INV_SCANNET_IDS = {idx: i for i, idx in enumerate(SCANNET_IDS)}

    @staticmethod
    def _normalize_scene_id(value):
        """Return the file name of ``value`` without directory and extension."""
        base = os.path.basename(value)
        if base.endswith('.bin'):
            base = base[:-4]
        else:
            base = os.path.splitext(base)[0]
        return base

    def __init__(self, path, bins_path):
        """Discover the usable scenes and gather their per-class scores.

        Parameters:
            path: directory containing the ``<scene>.npz`` prediction archives.
            bins_path: directory containing the ``<scene>.bin`` point files.
        """
        self.path = path
        self.bins_path = bins_path
        self.get_scenes()
        self.class_scores = defaultdict(list)
        for scene_id in self.scene_ids:
            self.get_scene_inst(scene_id)

        # Labels ordered from the most to the least frequently predicted.
        self.sorted_names = sorted(self.SCANNET_LABELS, key=lambda x: -len(self.class_scores.get(x, [])))

    def get_scenes(self):
        """Fill ``self.scene_ids`` with scenes that have both a prediction and a point file.

        A missing prediction directory yields an empty list. Names are sorted
        because ``os.listdir`` order is arbitrary and would otherwise make the
        exported pickle differ from run to run.
        """
        self.scene_ids = []
        if not os.path.isdir(self.path):
            return
        pred_files = sorted(f for f in os.listdir(self.path) if f.endswith('.npz'))
        for fname in pred_files:
            scene_name = os.path.splitext(fname)[0]
            bin_path = os.path.join(self.bins_path, f"{scene_name}.bin")
            if os.path.exists(bin_path):
                self.scene_ids.append(scene_name)

    def get_scene_inst(self, scene_id):
        """Append every predicted score of one scene to ``self.class_scores``.

        Raises:
            KeyError: if a predicted class id is not listed in ``SCANNET_IDS``.
        """
        cls_path = os.path.join(self.path, f'{scene_id}.npz')
        # NOTE(review): allow_pickle=True executes pickled payloads; only load
        # prediction archives that come from a trusted source.
        # The context manager closes the underlying zip file handle.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            pred_classes = cls_data['pred_classes']
            pred_scores = cls_data['pred_score']
        for class_id, class_score in zip(pred_classes, pred_scores):
            self.class_scores[self.ID2LABEL[class_id]].append(class_score)

    def _collect_scores(self, class_name):
        """Return the scores selected by ``class_name``.

        ``'all'`` merges every class, a list merges the listed classes (a
        warning is printed for each unknown name) and any other value is
        looked up as a single class name. Returns ``None`` (after printing a
        hint) when a single requested class is unknown.
        """
        if class_name == 'all':
            return [score for scores in self.class_scores.values() for score in scores]

        if isinstance(class_name, list):
            selected_scores = []
            for cls in class_name:
                if cls in self.class_scores:
                    selected_scores.extend(self.class_scores[cls])
                else:
                    print(f"Warning: Class '{cls}' not found in class_scores")
            return selected_scores

        if class_name not in self.class_scores:
            print(f"Class '{class_name}' not found in class_scores")
            print(f"Available classes: {list(self.class_scores.keys())[:10]}...")
            return None
        return self.class_scores[class_name]

    def plot_class_distr(self, class_name='all', percentile=32.45):
        """Plot the score distribution of one class, several classes or all classes.

        Parameters:
            class_name: str or list - a class name, 'all' for every class,
                or a list of class names.
            percentile: percentile drawn as the "Size bound" line and reported
                as ``Q`` in the statistics box.

        Returns:
            The plotted scores, or ``None`` when the class is unknown or there
            are no scores to plot.
        """
        scores = self._collect_scores(class_name)
        if scores is None:
            return None

        if class_name == 'all':
            display_name = 'All Classes'
            class_info = f"Total classes: {len(self.class_scores)}"
        elif isinstance(class_name, list):
            display_name = f'Classes: {", ".join(class_name[:3])}{"..." if len(class_name) > 3 else ""}'
            class_info = f"Selected classes: {len(class_name)}"
        else:
            display_name = class_name
            class_info = f"Class: {class_name}"

        if not scores:
            print(f"No scores available for: {display_name}")
            return None

        # Computed once; both values are used by the plot and the text box.
        median_score = np.median(scores)
        bound_score = np.percentile(scores, percentile)

        fig, ax = plt.subplots(figsize=(12, 8))

        sns.histplot(scores, bins=30, kde=True, ax=ax, color='skyblue',
                     stat='density', alpha=0.7)
        ax.set_title(f'Distribution of scores for {display_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)

        ax.axvline(median_score, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_score:.3f}')
        ax.axvline(bound_score, color='red', linestyle='-', linewidth=2,
                   label=f'Size bound: {bound_score:.3f}')

        ax.legend()

        stats_text = "\n".join([
            f"Statistics for {display_name}:",
            f"{class_info}",
            f"Total instances: {len(scores):,}",
            f"Mean: {np.mean(scores):.3f}",
            f"Median: {median_score:.3f}",
            f"Std: {np.std(scores):.3f}",
            f"Min: {np.min(scores):.3f}",
            f"Max: {np.max(scores):.3f}",
            f"Q1: {np.percentile(scores, 25):.3f}",
            f"Q : {bound_score:.3f}",
            f"Q3: {np.percentile(scores, 75):.3f}",
        ])

        props = dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8)
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontfamily='monospace',
                verticalalignment='top', bbox=props, fontsize=10)

        plt.tight_layout()
        plt.show()

        print(stats_text)

        return scores

    def plot_multiple_classes(self, class_names: list):
        """Compare the score distributions of several classes on a single plot.

        Unknown classes are skipped with a warning; classes without scores are
        silently left out of the plot.
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold', 'lightpink']

        for i, cls in enumerate(class_names):
            if cls not in self.class_scores:
                print(f"Warning: Class '{cls}' not found, skipping")
                continue

            scores = self.class_scores[cls]
            if scores:
                # The palette is cycled when more classes than colors are given.
                sns.kdeplot(scores, ax=ax, label=cls, color=colors[i % len(colors)],
                            linewidth=2, alpha=0.8)

        ax.set_title('Score Distribution Comparison', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend()

        plt.tight_layout()
        plt.show()

    def get_class_lowerbound(self, class_name='all', percentile=32.45):
        """Return the given percentile of the scores selected by ``class_name``.

        ``class_name`` follows the same convention as in ``plot_class_distr``.

        Returns:
            The percentile value, or ``None`` when the class is unknown or no
            scores are available (``np.percentile`` cannot handle an empty
            input).
        """
        scores = self._collect_scores(class_name)
        if scores is None:
            return None
        if len(scores) == 0:
            print(f"No scores available for: {class_name}")
            return None
        return np.percentile(scores, percentile)

    def get_bboxes_by_masks(self, masks, points):
        """Compute one axis-aligned box per mask from the masked points.

        Parameters:
            masks: iterable of per-point masks used to index ``points``; each
                mask is expected to select at least one point.
            points: tensor whose first three columns are the coordinates.

        Returns:
            Tensor with one row ``(center_xyz, size_xyz)`` per mask.

        Raises:
            ValueError: if ``masks`` is empty.
        """
        boxes = []
        for mask in masks:
            object_points = points[mask][:, :3]

            # The 1% / 99% quantiles are used instead of min / max so that a
            # few outlying points do not inflate the box.
            xyz_min = object_points.quantile(0.01, dim=0)
            xyz_max = object_points.quantile(0.99, dim=0)
            center = (xyz_max + xyz_min) / 2
            size = xyz_max - xyz_min
            boxes.append(torch.cat((center, size)))
        if len(boxes) == 0:
            # Explicit raise: an ``assert`` would be stripped under ``python -O``.
            raise ValueError("Why 0 masks in scene?")
        return torch.stack(boxes)

    def filter_instances_by_confidence(self, scene_name, threshold=0.5, class_agnostic=False):
        """Return the instances of a scene whose score is at least ``threshold``.

        Parameters:
            scene_name: scene identifier (file stem of the ``.bin`` / ``.npz`` files).
            threshold: minimal prediction score of a kept instance.
            class_agnostic: when True every instance gets label 0, otherwise
                the contiguous index of its class id in ``SCANNET_IDS``.

        Returns:
            List of dicts with keys ``bbox_3d`` (six floats) and
            ``bbox_label_3d``. The list is empty when a file is missing or no
            instance qualifies. Instances whose mask selects no point are
            skipped because no box can be computed for them.

        Raises:
            KeyError: if a kept class id is not listed in ``SCANNET_IDS``
                (only when ``class_agnostic`` is False).
        """
        instances = []
        points_path = f'{self.bins_path}/{scene_name}.bin'
        if not os.path.exists(points_path):
            return instances
        cls_path = f'{self.path}/{scene_name}.npz'
        if not os.path.exists(cls_path):
            return instances

        # The point file is read as a flat float32 buffer with six values per point.
        points = torch.from_numpy(np.fromfile(points_path, dtype=np.float32).reshape((-1, 6)))

        # NOTE(review): allow_pickle=True executes pickled payloads; only load
        # prediction archives that come from a trusted source.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            # Transposed so that every row is the mask of one instance; cast to
            # bool so that integer-valued masks are not interpreted as gather
            # indices when indexing the points.
            # NOTE(review): assumes pred_masks is stored as (num_points, num_instances) - confirm.
            pred_masks = torch.from_numpy(cls_data['pred_masks']).T.bool()
            pred_classes = cls_data['pred_classes']
            pred_scores = cls_data['pred_score']

        if len(pred_scores) == 0:
            return instances
        np_indices = np.where(pred_scores >= threshold)[0]
        if np_indices.size == 0:
            return instances

        selected_masks = pred_masks[torch.as_tensor(np_indices, dtype=torch.long)]

        # Drop instances without any point: the quantile of an empty tensor
        # raises and would abort the export of every remaining scene.
        non_empty = selected_masks.any(dim=1)
        if not bool(non_empty.all()):
            np_indices = np_indices[non_empty.numpy()]
            selected_masks = selected_masks[non_empty]
        if np_indices.size == 0:
            return instances

        boxes = self.get_bboxes_by_masks(selected_masks, points)
        selected_classes = pred_classes[np_indices]
        for box, pred_class in zip(boxes, selected_classes):
            write_class = 0 if class_agnostic else self.INV_SCANNET_IDS[pred_class]
            instances.append({'bbox_3d': box.numpy().tolist(), 'bbox_label_3d': write_class})
        return instances

    @classmethod
    def _build_metainfo(cls):
        """Return the ``metainfo`` dict, with categories derived from ``SCANNET_LABELS``."""
        categories = {label: index for index, label in enumerate(cls.SCANNET_LABELS)}
        return {'categories': categories, 'dataset': 'scannetpp', 'info_version': '1.0'}

    def build_pkl_by_confidence(self, pkl_path, threshold=0.5, class_agnostic=False):
        """Write a pickle with the confident instances of every discovered scene.

        Parameters:
            pkl_path: destination file of the pickle.
            threshold: minimal prediction score of an exported instance.
            class_agnostic: forwarded to ``filter_instances_by_confidence``.

        The pickle holds ``{'metainfo': ..., 'data_list': [...]}`` with one
        entry per scene in ``self.scene_ids``.
        """
        new_data: dict = {"metainfo": self._build_metainfo()}
        data_list = []
        for scene_name in tqdm(self.scene_ids):
            instances = self.filter_instances_by_confidence(scene_name, threshold=threshold, class_agnostic=class_agnostic)
            scene_entry = {
                'lidar_points': {
                    'num_pts_feats': 6,
                    'lidar_path': f'{scene_name}.bin',
                },
                'instances': instances,
                'pts_semantic_mask_path': f'{scene_name}.bin',
                'pts_instance_mask_path': f'{scene_name}.bin',
                'axis_align_matrix': np.eye(4, dtype=np.float32),
            }
            data_list.append(scene_entry)
        new_data['data_list'] = data_list
        with open(pkl_path, 'wb') as file:
            pickle.dump(new_data, file)

    @property
    def scores(self):
        """Mapping from class label to the list of collected scores."""
        return self.class_scores
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Directory holding the per-scene prediction archives (<scene>.npz).
    pred_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/data/prediction/scannetpp_v2_dust3r_unposed"
    )
    # Directory holding the per-scene point files (<scene>.bin).
    bins_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/points_dust3r_v2_unposed"
    )
    # Destination of the generated annotation pickle.
    out_pkl_path = (
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/scannetpp/bins/scannetpp84_v2_dust3r_unposed_train.pkl"
    )
    # Minimal prediction score of an exported instance.
    threshold = 0.5

    distr = PredBBoxDistrPP(pred_path, bins_path)
    distr.build_pkl_by_confidence(
        out_pkl_path,
        threshold=threshold,
        class_agnostic=False,
    )
|
|
|