|
|
import numpy as np |
|
|
import os |
|
|
import pickle |
|
|
from tqdm.auto import tqdm |
|
|
from collections import defaultdict |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from copy import deepcopy |
|
|
import torch |
|
|
from tqdm.contrib.concurrent import thread_map |
|
|
|
|
|
class PredBBoxDistrPP:
    """Turn per-scene predicted point masks into class-agnostic 3D boxes.

    The class reads three kinds of files:

    * ``<path>/<scene>.npz`` archives with the keys ``pred_masks``,
      ``pred_classes`` and ``pred_score``;
    * ``<bins_path>/<scene>.bin`` float32 point files that are reshaped to
      ``(-1, 6)``;
    * one ground-truth pickle (``gt_pkl_path``) whose ``data_list`` entries
      describe scenes.

    Every label written by this class is ``0`` and every collected score is
    stored under key ``0`` of ``class_scores``.
    """

    @staticmethod
    def _normalize_scene_id(value):
        """Return the part of *value* that precedes the first underscore."""
        return value.split("_")[0]

    def __init__(self, path, bins_path, gt_pkl_path, confidence_threshold=0.0, topk=True):
        """Store the input locations and the filtering options.

        Parameters:
            path: directory with the per-scene ``.npz`` prediction archives
            bins_path: directory with the per-scene ``.bin`` point files
            gt_pkl_path: path to the ground-truth pickle
            confidence_threshold: minimum score a prediction must reach in
                ``filter_instances_topk_by_gt``
            topk: when true, ``filter_instances_topk_by_gt`` keeps at most as
                many predictions as the scene has GT instances
        """
        self.path = path
        self.bins_path = bins_path
        self.gt_pkl_path = gt_pkl_path

        self.class_scores = defaultdict(list)
        self.confidence_threshold = confidence_threshold
        self.topk = topk

        # Filled by get_scenes(); created here so the attributes always exist.
        self.scene_ids = []
        self.gt_sample_counts = defaultdict(int)
        # Lazily built {normalized scene id -> scene dict}, see _get_gt_scene_index().
        self._gt_scene_index = None

    def _get_gt_scene_index(self):
        """Return a cached ``{normalized id: scene}`` map built from the GT pickle.

        The pickle is read only on the first call. If several threads hit the
        first call at once the index may be built more than once; every build
        produces the same mapping, so the race is harmless.
        """
        index = getattr(self, '_gt_scene_index', None)
        if index is None:
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point gt_pkl_path at trusted data.
            with open(self.gt_pkl_path, 'rb') as file:
                data = pickle.load(file)
            index = {}
            for scene in data.get('data_list', []):
                lidar_path = scene.get('lidar_points', {}).get('lidar_path')
                if not lidar_path:
                    continue
                # setdefault keeps the first scene for an id, matching a
                # front-to-back linear search over data_list.
                index.setdefault(self._normalize_scene_id(lidar_path), scene)
            self._gt_scene_index = index
        return index

    def load_pkl_scene_by_id(self, scene_id):
        """Return the GT scene description for *scene_id*, or ``None``.

        Both *scene_id* and each scene's ``lidar_points.lidar_path`` are
        reduced to the text before their first underscore and compared, so an
        id such as ``"sceneXXXX_YY"`` or a file name ending in ``.bin`` can be
        passed. Scenes without a ``lidar_path`` are ignored. The first
        matching scene in ``data_list`` wins.

        The returned dict is shared between calls (the pickle is loaded once
        and cached); treat it as read-only.
        """
        target_id = self._normalize_scene_id(scene_id)
        return self._get_gt_scene_index().get(target_id)

    def get_scenes(self):
        """Collect the GT scenes that also exist in ``self.path``.

        Resets and fills ``self.scene_ids`` with the extension-less
        ``lidar_path`` of every GT scene whose name equals the extension-less
        name of some entry in ``self.path``, and accumulates the total number
        of GT instances of those scenes in ``self.gt_sample_counts[0]``.
        """
        self.scene_ids = []
        self.gt_sample_counts = defaultdict(int)
        # NOTE(security): pickle.load executes arbitrary code from the file.
        with open(self.gt_pkl_path, 'rb') as file:
            data = pickle.load(file)

        picked_scenes = {os.path.splitext(entry)[0] for entry in os.listdir(self.path)}
        for scene in data['data_list']:
            scene_name = os.path.splitext(scene['lidar_points']['lidar_path'])[0]
            if scene_name not in picked_scenes:
                continue
            self.scene_ids.append(scene_name)
            instance_count = len(scene['instances'])
            if instance_count:
                # Counting is class-agnostic: everything lands under key 0.
                self.gt_sample_counts[0] += instance_count

    def get_scene_inst(self, scene_id):
        """Append the prediction scores of one scene to ``class_scores[0]``.

        Reads ``<path>/<scene_id>.npz``. Scores are stored class-agnostically
        under key ``0``; ``pred_classes`` is only used to bound how many
        scores are taken (one per class entry).
        """
        cls_path = f'{self.path}/{scene_id}.npz'
        # NOTE(security): allow_pickle=True lets the archive run arbitrary
        # code while loading; only read trusted prediction files.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            pred_scores = cls_data['pred_score']
            paired = min(len(cls_data['pred_classes']), len(pred_scores))
        if paired:
            self.class_scores[0].extend(pred_scores[:paired])

    def _collect_scores(self, class_name):
        """Gather the scores selected by *class_name*.

        Parameters:
            class_name: ``'all'`` for every stored class, a list of class
                keys, or a single class key

        Returns ``(scores, display_name)``. For a single class key that is not
        present in ``class_scores`` a message is printed and ``None`` is
        returned. Unknown keys inside a list only produce a warning.
        """
        if class_name == 'all':
            scores = []
            for class_bucket in self.class_scores.values():
                scores.extend(class_bucket)
            return scores, 'All Classes'

        if isinstance(class_name, list):
            scores = []
            for cls in class_name:
                if cls in self.class_scores:
                    scores.extend(self.class_scores[cls])
                else:
                    print(f"Warning: Class '{cls}' not found in class_scores")
            # Keys may be ints (this class stores scores under 0), so they
            # have to be converted before joining.
            shown = ", ".join(str(cls) for cls in class_name[:3])
            suffix = "..." if len(class_name) > 3 else ""
            return scores, f'Classes: {shown}{suffix}'

        if class_name not in self.class_scores:
            print(f"Class '{class_name}' not found in class_scores")
            print(f"Available classes: {list(self.class_scores.keys())[:10]}...")
            return None
        return self.class_scores[class_name], class_name

    def plot_class_distr(self, class_name='all', percentile=32.45):
        """Plot the score distribution of one class, several classes, or all.

        Draws a density histogram with a KDE curve, marks the median and the
        requested percentile ("Size bound"), shows a statistics box on the
        plot and prints the same statistics.

        Parameters:
            class_name: str or list - a class key, 'all' for every class,
                or a list of class keys
            percentile: percentile drawn as the "Size bound" line and
                reported as "Q" in the statistics

        Returns the list of plotted scores, or ``None`` when the class is
        unknown or there are no scores.
        """
        collected = self._collect_scores(class_name)
        if collected is None:
            return None
        scores, display_name = collected

        if not scores:
            print(f"No scores available for: {display_name}")
            return None

        median_score = np.median(scores)
        bound = np.percentile(scores, percentile)

        fig, ax = plt.subplots(figsize=(12, 8))
        sns.histplot(scores, bins=30, kde=True, ax=ax, color='skyblue',
                     stat='density', alpha=0.7)
        ax.set_title(f'Distribution of scores for {display_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)

        ax.axvline(median_score, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_score:.3f}')
        ax.axvline(bound, color='red', linestyle='-', linewidth=2,
                   label=f'Size bound: {bound:.3f}')
        ax.legend()

        if class_name == 'all':
            class_info = f"Total classes: {len(self.class_scores)}"
        elif isinstance(class_name, list):
            class_info = f"Selected classes: {len(class_name)}"
        else:
            class_info = f"Class: {class_name}"

        stats_text = "\n".join([
            f"Statistics for {display_name}:",
            class_info,
            f"Total instances: {len(scores):,}",
            f"Mean: {np.mean(scores):.3f}",
            f"Median: {median_score:.3f}",
            f"Std: {np.std(scores):.3f}",
            f"Min: {np.min(scores):.3f}",
            f"Max: {np.max(scores):.3f}",
            f"Q1: {np.percentile(scores, 25):.3f}",
            f"Q : {bound:.3f}",
            f"Q3: {np.percentile(scores, 75):.3f}",
        ])

        props = dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8)
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontfamily='monospace',
                verticalalignment='top', bbox=props, fontsize=10)

        plt.tight_layout()
        plt.show()

        print(stats_text)

        return scores

    def plot_multiple_classes(self, class_names: list):
        """Compare the score distributions of several classes on one plot.

        One KDE curve is drawn per class key that exists in ``class_scores``
        and has at least one score; unknown keys are reported and skipped.
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold', 'lightpink']
        drew_curve = False

        for i, cls in enumerate(class_names):
            if cls not in self.class_scores:
                print(f"Warning: Class '{cls}' not found, skipping")
                continue

            scores = self.class_scores[cls]
            if scores:
                # The colour is tied to the position in class_names, so it
                # cycles once more than five classes are requested.
                sns.kdeplot(scores, ax=ax, label=cls, color=colors[i % len(colors)],
                            linewidth=2, alpha=0.8)
                drew_curve = True

        ax.set_title('Score Distribution Comparison', fontsize=14, fontweight='bold')
        ax.set_xlabel('Score', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.grid(True, alpha=0.3)
        if drew_curve:
            # Without any labelled curve matplotlib only emits a warning.
            ax.legend()

        plt.tight_layout()
        plt.show()

    def get_class_lowerbound(self, class_name='all', percentile=32.45):
        """Return the given percentile of the selected scores.

        Parameters:
            class_name: ``'all'``, a list of class keys, or a single class key
            percentile: percentile passed to ``np.percentile``

        Returns ``None`` (after printing a message) when a single class key is
        not present in ``class_scores``.
        """
        collected = self._collect_scores(class_name)
        if collected is None:
            return None
        scores, _ = collected
        return np.percentile(scores, percentile)

    def get_bboxes_by_masks(self, masks, points):
        """Build one 7-value box per mask from the points it selects.

        For every mask the first three columns of ``points[mask]`` are taken
        and their per-column 1% and 99% quantiles are used as box extents
        (instead of the raw min/max). Each box is
        ``[center(3), size(3), 0]``; the trailing value is always zero.

        Parameters:
            masks: iterable of per-point index masks, one per instance
                (assumes boolean masks over the rows of *points* - TODO confirm)
            points: tensor whose first three columns are the coordinates

        Returns a tensor with one row per mask. Fails on an empty *masks*.
        """
        boxes = []
        for mask in masks:
            object_points = points[mask][:, :3]

            xyz_min = object_points.quantile(0.01, dim=0)
            xyz_max = object_points.quantile(0.99, dim=0)
            center = (xyz_max + xyz_min) / 2
            size = xyz_max - xyz_min
            # zeros_like(center)[:1] contributes the single trailing zero.
            box = torch.cat((center, size, torch.zeros_like(center)[:1]))
            boxes.append(box)
        assert len(boxes) != 0, "Why 0 masks in scene?"
        return torch.stack(boxes)

    def _load_aligned_points(self, points_path, gt_scene):
        """Read a float32 ``(-1, 6)`` point file and axis-align its xyz columns.

        When *gt_scene* has an ``axis_align_matrix`` its upper-left 3x3 part
        and the first three entries of its last column are applied to the
        first three point columns as ``xyz @ R.T + t``.
        """
        raw_points = np.fromfile(points_path, dtype=np.float32).reshape((-1, 6))
        points = torch.from_numpy(raw_points)

        if gt_scene is not None and 'axis_align_matrix' in gt_scene:
            align = torch.as_tensor(np.array(gt_scene['axis_align_matrix'], dtype=np.float32))
            rotation = align[:3, :3]
            translation = align[:3, 3]
            points[:, :3] = points[:, :3] @ rotation.T + translation
        return points

    def _load_predictions(self, scene_id):
        """Read ``<path>/<scene_id>.npz`` and return (masks, classes, scores).

        ``pred_masks`` is transposed before it is returned, so iterating the
        result yields one mask per prediction.
        """
        cls_path = f'{self.path}/{scene_id}.npz'
        # NOTE(security): allow_pickle=True lets the archive run arbitrary
        # code while loading; only read trusted prediction files.
        with np.load(cls_path, allow_pickle=True) as cls_data:
            pred_masks = torch.from_numpy(cls_data['pred_masks']).T
            pred_classes = cls_data['pred_classes']
            pred_scores = cls_data['pred_score']
        return pred_masks, pred_classes, pred_scores

    def get_scene_instances(self, scene_name, score_bounds, class_agnostic):
        """Return the boxes of all predictions that beat their score bound.

        Points come from ``<bins_path>/<scene_name>.bin`` and predictions from
        ``<path>/<scene_name>.npz``. A prediction is kept when its score is
        strictly greater than ``score_bounds.get(pred_class, 0)``.

        Parameters:
            scene_name: scene file stem
            score_bounds: mapping from predicted class to minimum score
            class_agnostic: currently unused - every label is written as 0
                (NOTE(review): confirm whether ``False`` should keep the
                predicted class)

        Returns a list of ``{'bbox_3d': [...], 'bbox_label_3d': 0}`` dicts.
        """
        points_path = f'{self.bins_path}/{scene_name}.bin'
        gt_scene = self.load_pkl_scene_by_id(scene_name)
        points = self._load_aligned_points(points_path, gt_scene)

        pred_masks, pred_classes, pred_scores = self._load_predictions(scene_name)
        boxes = self.get_bboxes_by_masks(pred_masks, points)

        return [
            {'bbox_3d': box.numpy().tolist(), 'bbox_label_3d': 0}
            for box, pred_class, pred_score in zip(boxes, pred_classes, pred_scores)
            if pred_score > score_bounds.get(pred_class, 0)
        ]

    def filter_instances_topk_by_gt(self, scene_name, class_agnostic=True):
        """Keep the best-scoring predictions of a scene and return their boxes.

        Steps:
            1) drop predictions whose score is below
               ``self.confidence_threshold``
            2) sort the rest by descending score
            3) keep the top K, where K is the number of GT instances of the
               scene in the pickle (all of them when ``self.topk`` is false)
            4) convert the kept masks to 3D boxes

        Points are read from ``<bins_path>/<id>_point.bin`` and predictions
        from ``<path>/<id>.npz`` where ``<id>`` is the normalized scene id.
        An empty list is returned when the scene has no GT instances or no
        prediction survives the threshold.

        Parameters:
            scene_name: scene id or file name
            class_agnostic: currently unused - every label is written as 0

        Returns a list of instances in the mmdet3d format
        (``bbox_3d``, ``bbox_label_3d``).
        """
        scene_id = self._normalize_scene_id(scene_name)
        gt_scene = self.load_pkl_scene_by_id(scene_name)
        gt_count = len(gt_scene.get('instances', [])) if gt_scene else 0
        if gt_count <= 0:
            return []

        points_path = f'{self.bins_path}/{scene_id}_point.bin'
        points = self._load_aligned_points(points_path, gt_scene)

        pred_masks, _, pred_scores = self._load_predictions(scene_id)

        # Apply the confidence threshold. Previously the mask was computed
        # and then thrown away, so the threshold had no effect.
        keep = np.asarray(pred_scores >= self.confidence_threshold, dtype=bool)
        pred_masks = pred_masks[torch.from_numpy(keep)]
        pred_scores = pred_scores[keep]

        if len(pred_scores) == 0:
            return []

        limit = min(gt_count, len(pred_scores)) if self.topk else len(pred_scores)
        best_first = np.argsort(-pred_scores)[:limit]

        selected_masks = pred_masks[torch.as_tensor(best_first, dtype=torch.long)]
        boxes = self.get_bboxes_by_masks(selected_masks, points)

        return [{'bbox_3d': box, 'bbox_label_3d': 0} for box in boxes.tolist()]

    @property
    def scores(self):
        """The ``class_scores`` mapping (class key -> list of scores)."""
        return self.class_scores
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: rebuild the GT info pickle so that every scene that
    # has a point file in ``bins_path`` carries predicted instances instead of
    # its ground-truth ones, then write the result to ``out_pkl_path``.
    pred_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/data/prediction/arkit_vggt"
    bins_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/arkitscenes/points_vggt"
    out_pkl_path = \
        "arkit_vggt_ca_ct05_topk_false.pkl"
    gt_pkl_path = \
        "/home/jovyan/users/bulat/workspace/3drec/Indoor/OKNO/data/arkitscenes/arkitscenes_offline_infos_train.pkl"
    confidence_threshold = 0.5
    distr = PredBBoxDistrPP(
        pred_path,
        bins_path,
        gt_pkl_path,
        confidence_threshold=confidence_threshold,
        topk=False,
    )

    with open(gt_pkl_path, 'rb') as file:
        gt_data = pickle.load(file)

    # Scene ids that have a point file: text before the first underscore of
    # every entry in bins_path.
    picked_scenes = {entry.split("_")[0] for entry in os.listdir(bins_path)}
    scene_names = [
        distr._normalize_scene_id(scene['lidar_points']['lidar_path'])
        for scene in gt_data['data_list']
    ]

    # Keep the position in gt_data['data_list'] next to each selected name so
    # the results can be written back to the matching GT entry.
    selected = [
        (position, scene_name)
        for position, scene_name in enumerate(scene_names)
        if scene_name in picked_scenes
    ]
    indices = [position for position, _ in selected]
    data = [scene_name for _, scene_name in selected]

    instances = thread_map(distr.filter_instances_topk_by_gt, data, chunksize=128)

    data_list = []
    for gt_position, scene_instances in zip(indices, instances):
        # Copy the GT entry so the loaded pickle itself is left untouched.
        tmp_scene = deepcopy(gt_data['data_list'][gt_position])
        tmp_scene['instances'] = scene_instances
        data_list.append(tmp_scene)

    new_data = {"metainfo": gt_data["metainfo"]}
    new_data['data_list'] = data_list
    with open(out_pkl_path, 'wb') as f:
        pickle.dump(new_data, f)
|
|
|