| import os |
| import glob |
| import torch |
| import numpy as np |
| import warnings |
| import trimesh |
| from scipy.stats import entropy |
| from sklearn.neighbors import NearestNeighbors |
| from numpy.linalg import norm |
| from tqdm.auto import tqdm |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
# Directory holding the generated .obj meshes to be evaluated.
GENERATED_MESH_DIR = "/root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs"
# Directory holding the ground-truth .obj meshes.
GT_MESH_DIR = "/root/Trisf/abalation_post_processing/gt_mesh"


# Number of surface points sampled from every mesh.
NUM_POINTS_PER_MESH = 2048
# Mini-batch size for the pairwise distance computations (memory/speed trade-off).
BATCH_SIZE = 32
# Occupancy-grid resolution for JSD (grid of JSD_RESOLUTION^3 cells in the unit cube).
JSD_RESOLUTION = 28
|
|
| |
| |
| |
|
|
def process_meshes_in_folder(folder_path, num_points):
    """
    Load every .obj mesh in ``folder_path``, sample each surface into a
    point cloud, and normalize it into the unit sphere.

    Normalization: translate the bounding-box center to the origin, then
    scale so the farthest vertex lies on the unit sphere.

    Args:
        folder_path: directory containing the .obj files.
        num_points: number of surface samples per mesh.

    Returns:
        np.ndarray of shape (num_meshes, num_points, 3).

    Raises:
        FileNotFoundError: if no .obj file is found in ``folder_path``.
    """
    mesh_files = sorted(glob.glob(os.path.join(folder_path, '*.obj')))
    if not mesh_files:
        raise FileNotFoundError(f"在文件夹 '{folder_path}' 中没有找到任何 .obj 文件。")

    all_point_clouds = []
    print(f"正在从 '{folder_path}' 处理 {len(mesh_files)} 个mesh...")

    for mesh_path in tqdm(mesh_files, desc=f'处理 {os.path.basename(folder_path)}'):
        try:
            mesh = trimesh.load(mesh_path, process=False)
            # Fix: trimesh.load returns a Scene (not a Trimesh) for .obj files
            # containing multiple objects/groups; a Scene has no .vertices, so
            # the normalization below would raise and the mesh would be
            # silently dropped. Merge the sub-meshes into one Trimesh first.
            if isinstance(mesh, trimesh.Scene):
                mesh = trimesh.util.concatenate(tuple(mesh.geometry.values()))

            # Center on the axis-aligned bounding-box midpoint...
            center = mesh.bounds.mean(axis=0)
            mesh.apply_translation(-center)
            # ...then scale the farthest vertex onto the unit sphere.
            max_dist = np.max(np.linalg.norm(mesh.vertices, axis=1))
            if max_dist > 0:
                mesh.apply_scale(1.0 / max_dist)

            points, _ = trimesh.sample.sample_surface(mesh, num_points)

            # sample_surface may return fewer samples than requested for
            # degenerate meshes; pad by resampling with replacement.
            if points.shape[0] != num_points:
                indices = np.random.choice(points.shape[0], num_points, replace=True)
                points = points[indices]

            all_point_clouds.append(points)

        except Exception as e:
            # NOTE(review): a failed mesh is skipped, which shifts the index
            # alignment against the other set in compute_cd_hd — watch this
            # message when interpreting the paired metrics.
            print(f"错误:加载或处理文件 {mesh_path} 失败: {e}")

    return np.array(all_point_clouds)
|
|
| |
| |
| |
|
|
# One-shot guard so the missing-EMD warning is printed only once per run.
_EMD_NOT_IMPL_WARNED = False


def emd_approx(sample, ref):
    """Placeholder EMD: always returns a zero per batch element.

    The real Earth Mover's Distance is not implemented; this stub keeps the
    metric pipeline's interface intact and warns once on first use.
    """
    global _EMD_NOT_IMPL_WARNED
    result = torch.zeros([sample.size(0)]).to(sample)
    if not _EMD_NOT_IMPL_WARNED:
        _EMD_NOT_IMPL_WARNED = True
        print('\n\n[WARNING] EMD is not implemented. Setting to zero.')
    return result
|
|
def distChamfer(a, b):
    """Batched squared nearest-neighbor (Chamfer) distances.

    Fix/generalization: the original built one diagonal index from ``a``'s
    point count and used it to index both Gram matrices, so it broke when
    ``a`` and ``b`` had different numbers of points. The broadcast form
    below is numerically identical for equal counts and also supports
    unequal counts.

    Args:
        a: (bs, n_a, d) point clouds.
        b: (bs, n_b, d) point clouds.

    Returns:
        Tuple:
          - (bs, n_b): squared distance from each point of ``b`` to its
            nearest point in ``a`` (min over dim 1 of the cost matrix).
          - (bs, n_a): squared distance from each point of ``a`` to its
            nearest point in ``b``.
    """
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    a_sq = (a * a).sum(dim=2)                  # (bs, n_a)
    b_sq = (b * b).sum(dim=2)                  # (bs, n_b)
    cross = torch.bmm(a, b.transpose(2, 1))    # (bs, n_a, n_b)
    P = a_sq.unsqueeze(2) + b_sq.unsqueeze(1) - 2 * cross
    # Same return convention as the original implementation.
    return P.min(1)[0], P.min(2)[0]
|
|
def compute_cd_hd(sample_pcs, ref_pcs, batch_size):
    """
    Compute the mean paired Chamfer (CD) and Hausdorff (HD) distances.

    Clouds are compared index-to-index, so ``sample_pcs[i]`` must correspond
    to ``ref_pcs[i]``.

    Args:
        sample_pcs: (N, P, 3) tensor of generated point clouds.
        ref_pcs: (N, P, 3) tensor of reference point clouds.
        batch_size: number of pairs fed to distChamfer per step.

    Returns:
        dict with 'Chamfer-L2' (mean of squared-distance CD) and
        'Hausdorff' (mean symmetric Hausdorff distance, linear units).
    """
    print("\n--- 开始计算 平均Chamfer和Hausdorff距离 ---")
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    assert N_sample == N_ref, f"用于成对度量计算的集合大小必须相等, 但得到 {N_sample} 和 {N_ref}"

    cd_all = []
    hd_all = []

    for b_start in tqdm(range(0, N_sample, batch_size), desc='计算 CD/HD'):
        b_end = min(N_sample, b_start + batch_size)
        sample_batch = sample_pcs[b_start:b_end]
        ref_batch = ref_pcs[b_start:b_end]

        dist1_sq, dist2_sq = distChamfer(sample_batch, ref_batch)
        # Fix: the Gram-matrix formulation inside distChamfer can yield tiny
        # negative values through floating-point cancellation; clamp so the
        # .sqrt() below can never produce NaN.
        dist1_sq = dist1_sq.clamp(min=0)
        dist2_sq = dist2_sq.clamp(min=0)

        # Chamfer-L2: sum of the mean squared NN distances in both directions.
        cd_all.append(dist1_sq.mean(dim=1) + dist2_sq.mean(dim=1))

        # Symmetric Hausdorff: worst NN distance over both directions, then
        # sqrt to report in linear (non-squared) units.
        hd_all.append(torch.max(dist1_sq.max(dim=1)[0], dist2_sq.max(dim=1)[0]).sqrt())

    results = {
        'Chamfer-L2': torch.cat(cd_all).mean(),
        'Hausdorff': torch.cat(hd_all).mean(),
    }
    return results
|
|
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
    """Full pairwise Chamfer-distance matrix between two point-cloud sets.

    Returns ``(all_cd, all_emd)`` of shape (N_sample, N_ref); ``all_emd`` is
    a zero-filled placeholder since EMD is not implemented here.
    """
    n_sample = sample_pcs.shape[0]
    n_ref = ref_pcs.shape[0]
    rows = []
    outer = range(n_sample)
    if verbose:
        outer = tqdm(outer, desc='计算点云间距离')
    for row_idx in outer:
        one_sample = sample_pcs[row_idx]
        chunks = []
        # Sweep the reference set in chunks to bound peak memory.
        for start in range(0, n_ref, batch_size):
            ref_chunk = ref_pcs[start:min(n_ref, start + batch_size)]
            # Replicate the single sample cloud across the chunk for a
            # batched Chamfer evaluation.
            tiled = one_sample.view(1, -1, 3).expand(ref_chunk.size(0), -1, -1).contiguous()
            d_left, d_right = distChamfer(tiled, ref_chunk)
            chunks.append((d_left.mean(dim=1) + d_right.mean(dim=1)).view(1, -1))
        rows.append(torch.cat(chunks, dim=1))
    all_cd = torch.cat(rows, dim=0)
    # EMD placeholder — keeps the (cd, emd) return convention.
    all_emd = torch.zeros_like(all_cd)
    return all_cd, all_emd
|
|
def knn(Mxx, Mxy, Myy, k, sqrt=False):
    """1-NN two-sample classification accuracy (the 1-NNA metric).

    Builds the joint (n0+n1) x (n0+n1) distance matrix from the within-set
    blocks ``Mxx``/``Myy`` and the cross block ``Mxy``, labels the first set
    1 and the second 0, and scores k-NN majority-vote leave-one-out
    predictions. An accuracy near 0.5 means the two sets are indistinguishable.
    """
    n0, n1 = Mxx.size(0), Myy.size(0)
    device = Mxx.device

    label = torch.cat((torch.ones(n0, device=device),
                       torch.zeros(n1, device=device)))

    top = torch.cat((Mxx, Mxy), dim=1)
    bottom = torch.cat((Mxy.t(), Myy), dim=1)
    M = torch.cat((top, bottom), dim=0)
    if sqrt:
        M = M.abs().sqrt()

    # Mask the diagonal with +inf so a point never counts itself as neighbor.
    M = M + torch.diag(torch.full((n0 + n1,), float('inf'), device=device))
    _, nn_idx = M.topk(k, dim=0, largest=False)

    # Majority vote: sum the labels of the k nearest neighbors per column.
    count = label[nn_idx].sum(dim=0)
    pred = (count >= float(k) / 2).float()

    return {'acc': (label == pred).float().mean()}
|
|
def lgan_mmd_cov(all_dist):
    """MMD and coverage from an (N_sample, N_ref) distance matrix.

    - lgan_mmd: for every reference, the distance to its closest sample,
      averaged over references (lower is better).
    - lgan_cov: fraction of references matched as nearest neighbor by at
      least one sample (higher is better).
    """
    n_ref = all_dist.shape[1]

    # Nearest reference for every sample (argmin drives coverage).
    _, matched_ref = all_dist.min(dim=1)
    covered = matched_ref.unique().numel() / float(n_ref)
    cov = torch.tensor(covered, device=all_dist.device)

    # Nearest sample for every reference drives MMD.
    per_ref_min, _ = all_dist.min(dim=0)
    mmd = per_ref_min.mean()

    return {'lgan_mmd': mmd, 'lgan_cov': cov}
|
|
def compute_mmd_cov_1nna(sample_pcs, ref_pcs, batch_size):
    """Distribution metrics between generated and reference point clouds.

    Returns a dict with 'lgan_mmd-CD', 'lgan_cov-CD' and '1-NNA-CD', all
    based on Chamfer distance.
    """
    results = {}
    print("\n--- 开始计算 MMD-CD, COV-CD, 1-NNA-CD ---")

    # Cross-set Chamfer matrix, shaped (N_ref, N_sample).
    dist_rs, _ = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)

    # lgan_mmd_cov expects (N_sample, N_ref), hence the transpose.
    for key, value in lgan_mmd_cov(dist_rs.t()).items():
        results[f"{key}-CD"] = value

    # Within-set matrices needed for the 1-NN two-sample test.
    dist_rr, _ = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
    dist_ss, _ = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)

    results["1-NNA-CD"] = knn(dist_rr, dist_rs, dist_ss, 1)['acc']

    return results
|
|
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Regular resolution^3 grid of points filling the cube [-0.5, 0.5]^3.

    Points are ordered lexicographically (x slowest, z fastest). When
    ``clip_sphere`` is set, only points inside the radius-0.5 ball are kept.
    """
    axis = np.linspace(-0.5, 0.5, resolution)
    points = np.array([(x, y, z) for x in axis for y in axis for z in axis])
    if clip_sphere:
        keep = norm(points, axis=1) <= 0.5
        points = points[keep]
    return points
|
|
def entropy_of_occupancy_grid(pclouds, grid_resolution):
    """Occupancy histogram of a set of point clouds over a spherical grid.

    For each grid cell (clipped to the unit sphere), counts how many point
    clouds place at least one point nearest to that cell. Despite the name,
    this returns the raw per-cell counters, not an entropy value.
    """
    cell_centers = unit_cube_grid_point_cloud(grid_resolution, True)
    occupancy = np.zeros(len(cell_centers))
    finder = NearestNeighbors(n_neighbors=1).fit(cell_centers)

    for cloud in tqdm(pclouds, desc='计算占据网格'):
        _, nearest = finder.kneighbors(cloud)
        # Each cloud contributes at most 1 per cell, however many points land there.
        occupied = np.unique(nearest.squeeze())
        occupancy[occupied] += 1
    return occupancy
|
|
def jensen_shannon_divergence(P, Q):
    """Jensen-Shannon divergence (in bits) between two histograms.

    Inputs are normalized internally (with a tiny epsilon against empty
    histograms); the result lies in [0, 1] with base-2 logarithms.
    """
    p = P / (P.sum() + 1e-9)
    q = Q / (Q.sum() + 1e-9)
    mid = 0.5 * (p + q)
    return 0.5 * (entropy(p, mid, base=2) + entropy(q, mid, base=2))
|
|
def compute_jsd(sample_pcs, ref_pcs, resolution):
    """JSD between the occupancy-grid distributions of two point-cloud sets."""
    print("\n--- 开始计算 JSD ---")
    generated_hist = entropy_of_occupancy_grid(sample_pcs, resolution)
    reference_hist = entropy_of_occupancy_grid(ref_pcs, resolution)
    return jensen_shannon_divergence(generated_hist, reference_hist)
|
|
| |
| |
| |
if __name__ == '__main__':
    # Sample a normalized point cloud from every generated / ground-truth mesh.
    sample_pcs_np = process_meshes_in_folder(GENERATED_MESH_DIR, NUM_POINTS_PER_MESH)
    ref_pcs_np = process_meshes_in_folder(GT_MESH_DIR, NUM_POINTS_PER_MESH)

    print(f"\n加载完成: {sample_pcs_np.shape[0]} 个生成点云, {ref_pcs_np.shape[0]} 个真实点云。")
    print(f"每个点云包含 {sample_pcs_np.shape[1]} 个点。")

    # Torch-based metrics run on GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    sample_pcs_torch = torch.from_numpy(sample_pcs_np).float().to(device)
    ref_pcs_torch = torch.from_numpy(ref_pcs_np).float().to(device)

    # Distribution metrics: MMD-CD, COV-CD, 1-NNA-CD.
    metrics_results = compute_mmd_cov_1nna(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)

    # Paired fidelity metrics — NOTE(review): assumes sample[i] corresponds
    # to ref[i]; meshes skipped during loading would break this alignment.
    cd_hd_results = compute_cd_hd(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
    metrics_results.update(cd_hd_results)

    # JSD is computed on the numpy clouds directly (CPU occupancy grids).
    jsd_result = compute_jsd(sample_pcs_np, ref_pcs_np, JSD_RESOLUTION)

    # ---- report ----
    print("\n==================================================")
    print("                 评估结果")
    print("==================================================")

    print("\n--- 分布质量与多样性 (Distribution Metrics) ---")

    print(f"{'lgan_mmd-CD':<12s}: {metrics_results['lgan_mmd-CD'].item():.6f} (↓ Lower is better)")

    print(f"{'lgan_cov-CD':<12s}: {metrics_results['lgan_cov-CD'].item():.6f} (↑ Higher is better)")

    print(f"{'1-NNA-CD':<12s}: {metrics_results['1-NNA-CD'].item():.6f} (→ Closer to 0.5 is better)")

    print(f"{'JSD':<12s}: {jsd_result:.6f} (↓ Lower is better)")

    print("\n--- 平均几何保真度 (Average Geometric Fidelity) ---")

    print(f"{'Chamfer-L2':<12s}: {metrics_results['Chamfer-L2'].item():.6f} (↓ Lower is better)")

    print(f"{'Hausdorff':<12s}: {metrics_results['Hausdorff'].item():.6f} (↓ Lower is better)")

    print("==================================================")