# metric.py — mesh-generation evaluation metrics
# (uploaded via huggingface_hub, revision 7340df2)
import os
import glob
import torch
import numpy as np
import warnings
import trimesh
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from tqdm.auto import tqdm
# ==============================================================================
# User Configuration
# ==============================================================================
# --- Path settings ---
# !! IMPORTANT !!
# When computing real metrics, make sure these two paths point to DIFFERENT
# folders. They are sometimes set to the same folder for sanity testing; in
# that case the expected results are:
# MMD->0, COV->1.0, 1-NNA->0.5, JSD->0
# while CD and HD will be small non-zero values, reflecting the difference
# between two independent surface samplings of the same mesh.
# GENERATED_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # folder containing generated .obj files
# GT_MESH_DIR = "/root/mesh_split_200complex/mesh_split_200complex_test" # folder containing ground-truth .obj files
GENERATED_MESH_DIR = "/root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs" # folder containing generated .obj files
GT_MESH_DIR = "/root/Trisf/abalation_post_processing/gt_mesh" # folder containing ground-truth .obj files
# --- Sampling and computation parameters ---
NUM_POINTS_PER_MESH = 2048 # number of points sampled from each mesh surface
BATCH_SIZE = 32 # batch size used when computing metrics; tune to GPU memory
JSD_RESOLUTION = 28 # voxel-grid resolution used in the JSD computation
# ==============================================================================
# Core functions: mesh processing
# ==============================================================================
def process_meshes_in_folder(folder_path, num_points):
    """Load every .obj mesh in ``folder_path``, sample each surface into a
    normalized point cloud of ``num_points`` points.

    Normalization translates the bounding-box center to the origin and scales
    the mesh so all vertices fit inside the unit sphere.

    Args:
        folder_path: directory containing the .obj files.
        num_points: number of surface points to sample per mesh.

    Returns:
        np.ndarray of shape (num_meshes, num_points, 3).

    Raises:
        FileNotFoundError: if the folder contains no .obj files.
        RuntimeError: if any mesh fails to load or sample. Failing fast is
            important here: silently skipping a file (the previous behavior)
            would shift the sorted one-to-one pairing between generated and
            ground-truth meshes and corrupt the pairwise CD/HD metrics.
    """
    # Sort by filename so the generated and GT folders pair up index-by-index.
    mesh_files = sorted(glob.glob(os.path.join(folder_path, '*.obj')))
    if not mesh_files:
        raise FileNotFoundError(f"在文件夹 '{folder_path}' 中没有找到任何 .obj 文件。")
    all_point_clouds = []
    print(f"正在从 '{folder_path}' 处理 {len(mesh_files)} 个mesh...")
    for mesh_path in tqdm(mesh_files, desc=f'处理 {os.path.basename(folder_path)}'):
        try:
            mesh = trimesh.load(mesh_path, process=False)
            # Normalize: move the bounding-box center to the origin, then
            # scale the mesh into the unit sphere.
            center = mesh.bounds.mean(axis=0)
            mesh.apply_translation(-center)
            max_dist = np.max(np.linalg.norm(mesh.vertices, axis=1))
            if max_dist > 0:
                mesh.apply_scale(1.0 / max_dist)
            points, _ = trimesh.sample.sample_surface(mesh, num_points)
            if points.shape[0] != num_points:
                # sample_surface can occasionally return a different count;
                # resample (with replacement) to the exact requested size.
                indices = np.random.choice(points.shape[0], num_points, replace=True)
                points = points[indices]
            all_point_clouds.append(points)
        except Exception as e:
            # Do NOT skip and continue: a missing entry would silently
            # misalign the sorted pairing assumed by the pairwise metrics.
            raise RuntimeError(f"错误:加载或处理文件 {mesh_path} 失败: {e}") from e
    return np.array(all_point_clouds)
# ==============================================================================
# Evaluation metric code (adapted from PointFlow, plus additions)
# ==============================================================================
_EMD_NOT_IMPL_WARNED = False
def emd_approx(sample, ref):
    """Placeholder EMD: always returns a zero per batch element.

    A warning is printed the first time this is called so the user knows any
    reported EMD values are meaningless. ``ref`` is accepted for interface
    compatibility but unused.
    """
    global _EMD_NOT_IMPL_WARNED
    if not _EMD_NOT_IMPL_WARNED:
        print('\n\n[WARNING] EMD is not implemented. Setting to zero.')
        _EMD_NOT_IMPL_WARNED = True
    return torch.zeros([sample.size(0)]).to(sample)
def distChamfer(a, b):
    """Batched squared nearest-neighbor distances between two point clouds.

    Both clouds must contain the same number of points (all callers in this
    file satisfy this).

    Args:
        a: tensor of shape (bs, n, 3).
        b: tensor of shape (bs, n, 3).

    Returns:
        Tuple ``(d_b, d_a)`` where ``d_b[k, j]`` is the squared distance from
        ``b[k, j]`` to its nearest point in ``a[k]``, and ``d_a[k, i]`` is the
        squared distance from ``a[k, i]`` to its nearest point in ``b[k]``.
    """
    npts = a.size(1)
    # Gram matrices give dot products; squared distances then follow from
    # ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u.v
    gram_aa = torch.bmm(a, a.transpose(2, 1))
    gram_bb = torch.bmm(b, b.transpose(2, 1))
    gram_ab = torch.bmm(a, b.transpose(2, 1))
    diag = torch.arange(0, npts, device=a.device).long()
    norms_a = gram_aa[:, diag, diag].unsqueeze(1).expand_as(gram_aa)
    norms_b = gram_bb[:, diag, diag].unsqueeze(1).expand_as(gram_bb)
    # sq_dists[k, i, j] = ||a[k, i] - b[k, j]||^2
    sq_dists = norms_a.transpose(2, 1) + norms_b - 2 * gram_ab
    return sq_dists.min(1)[0], sq_dists.min(2)[0]
def compute_cd_hd(sample_pcs, ref_pcs, batch_size):
    """Compute the mean pairwise Chamfer Distance (CD) and Hausdorff Distance (HD).

    The i-th generated cloud is compared with the i-th ground-truth cloud, so
    both tensors must be index-aligned and of equal length.

    Args:
        sample_pcs: tensor (N, P, 3) of generated point clouds.
        ref_pcs: tensor (N, P, 3) of ground-truth point clouds.
        batch_size: number of cloud pairs processed at once.

    Returns:
        dict with:
            'Chamfer-L2': mean CD (sum of mean squared NN distances, no sqrt).
            'Hausdorff': mean symmetric Hausdorff distance (sqrt applied).
    """
    print("\n--- 开始计算 平均Chamfer和Hausdorff距离 ---")
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    # Fix: the original message concatenated the two sizes with no separator,
    # rendering e.g. "3230" for 32 vs 30.
    assert N_sample == N_ref, f"用于成对度量计算的集合大小必须相等, 但得到 {N_sample} 和 {N_ref}"
    cd_all = []
    hd_all = []
    iterator = range(0, N_sample, batch_size)
    for b_start in tqdm(iterator, desc='计算 CD/HD'):
        b_end = min(N_sample, b_start + batch_size)
        sample_batch = sample_pcs[b_start:b_end]
        ref_batch = ref_pcs[b_start:b_end]
        # distChamfer returns SQUARED nearest-neighbor distances.
        dist1_sq, dist2_sq = distChamfer(sample_batch, ref_batch)
        # Chamfer distance: symmetric sum of mean squared distances.
        cd_batch = dist1_sq.mean(dim=1) + dist2_sq.mean(dim=1)
        cd_all.append(cd_batch)
        # Hausdorff distance: HD = max(max(min_dist_1), max(min_dist_2)),
        # with sqrt to convert squared distances to actual distances.
        hd_batch = torch.max(dist1_sq.max(dim=1)[0], dist2_sq.max(dim=1)[0]).sqrt()
        hd_all.append(hd_batch)
    cd_all = torch.cat(cd_all)
    hd_all = torch.cat(hd_all)
    results = {
        'Chamfer-L2': cd_all.mean(),
        'Hausdorff': hd_all.mean(),
    }
    return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
    """Build the full (N_sample x N_ref) matrix of Chamfer distances.

    Each sample cloud is expanded against batches of reference clouds so the
    Chamfer terms can be computed with a single batched call per chunk.
    EMD is not implemented; a zero matrix of the same shape is returned in
    its place so callers keep the ``(cd, emd)`` tuple interface.
    """
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    outer = range(N_sample)
    if verbose:
        outer = tqdm(outer, desc='计算点云间距离')
    rows = []
    for i in outer:
        one_cloud = sample_pcs[i]
        row_chunks = []
        for start in range(0, N_ref, batch_size):
            stop = min(N_ref, start + batch_size)
            chunk = ref_pcs[start:stop]
            # Tile this sample cloud across the reference chunk.
            tiled = one_cloud.view(1, -1, 3).expand(chunk.size(0), -1, -1).contiguous()
            dl, dr = distChamfer(tiled, chunk)
            row_chunks.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
        rows.append(torch.cat(row_chunks, dim=1))
    all_cd = torch.cat(rows, dim=0)
    # EMD is not implemented; return a zero placeholder of matching shape.
    return all_cd, torch.zeros_like(all_cd)
def knn(Mxx, Mxy, Myy, k, sqrt=False):
    """1-NN two-sample classification accuracy (the 1-NNA metric).

    Builds the joint (n0+n1) x (n0+n1) distance matrix of both sets, labels
    the first set 1 and the second 0, and measures how often a point's k
    nearest neighbors (excluding itself) vote for its own set. An accuracy
    near 0.5 means the two sets are statistically indistinguishable.
    """
    n0, n1 = Mxx.size(0), Myy.size(0)
    device = Mxx.device
    label = torch.cat((torch.ones(n0, device=device), torch.zeros(n1, device=device)))
    top_rows = torch.cat((Mxx, Mxy), 1)
    bottom_rows = torch.cat((Mxy.t(), Myy), 1)
    M = torch.cat([top_rows, bottom_rows], 0)
    if sqrt:
        M = M.abs().sqrt()
    # +inf on the diagonal so a point is never chosen as its own neighbor.
    M = M + torch.diag(torch.full((n0 + n1,), float('inf'), device=device))
    _, idx = M.topk(k, 0, False)  # k smallest along dim 0 -> neighbors per column
    # Sum the labels of the k neighbors of each point (majority vote).
    count = label.index_select(0, idx.reshape(-1)).view(k, -1).sum(dim=0)
    pred = (count >= torch.full((n0 + n1,), float(k) / 2, device=device)).float()
    return {'acc': (label == pred).float().mean()}
def lgan_mmd_cov(all_dist):
    """Compute MMD and COV from a (N_sample, N_ref) distance matrix.

    MMD (minimum matching distance): for each reference, the distance to its
    closest sample, averaged over references — lower is better.
    COV (coverage): fraction of references that are the nearest neighbor of
    at least one sample — higher is better.

    Args:
        all_dist: tensor of shape (N_sample, N_ref).

    Returns:
        dict with 'lgan_mmd' and 'lgan_cov', both torch scalars on
        ``all_dist``'s device.

    Note: the previous version also computed a sample-side MMD (``mmd_smp``)
    that was never used or returned; that dead code has been removed.
    """
    _, N_ref = all_dist.shape
    # For each sample, the index of its closest reference (drives coverage).
    _, min_idx = all_dist.min(dim=1)
    # For each reference, the distance to its closest sample (drives MMD).
    min_val_ref, _ = all_dist.min(dim=0)
    mmd = min_val_ref.mean()
    cov = torch.tensor(min_idx.unique().numel() / float(N_ref), device=all_dist.device)
    return {'lgan_mmd': mmd, 'lgan_cov': cov}
def compute_mmd_cov_1nna(sample_pcs, ref_pcs, batch_size):
    """Compute the CD-based distribution metrics MMD, COV and 1-NNA."""
    print("\n--- 开始计算 MMD-CD, COV-CD, 1-NNA-CD ---")
    results = {}
    # Distance matrix: ref clouds (rows) vs sample clouds (cols).
    M_rs_cd, _ = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
    # Transpose so rows are samples, as lgan_mmd_cov expects.
    for key, value in lgan_mmd_cov(M_rs_cd.t()).items():
        results[f"{key}-CD"] = value
    # Within-set distance matrices for the 1-NN two-sample test.
    M_rr_cd, _ = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
    M_ss_cd, _ = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
    results["1-NNA-CD"] = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1)['acc']
    return results
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Return the vertices of a regular grid filling the [-0.5, 0.5]^3 cube.

    Output shape is (resolution**3, 3). If ``clip_sphere`` is True, only the
    points inside the radius-0.5 ball are kept.
    """
    axis = np.linspace(-0.5, 0.5, resolution)
    xs, ys, zs = np.meshgrid(axis, axis, axis, indexing='ij')
    points = np.stack([xs, ys, zs], axis=-1).reshape(-1, 3)
    if clip_sphere:
        points = points[np.linalg.norm(points, axis=1) <= 0.5]
    return points
def entropy_of_occupancy_grid(pclouds, grid_resolution):
    """Count, per voxel-grid cell, how many point clouds occupy that cell.

    Each cloud's points are snapped to their nearest grid vertex (the grid is
    clipped to the unit sphere); a cell is counted at most once per cloud.
    NOTE: despite the name, this returns raw occupancy counts, not an
    entropy — the JSD computation normalizes them later.
    """
    cell_centers = unit_cube_grid_point_cloud(grid_resolution, True)
    occupancy = np.zeros(len(cell_centers))
    nbrs = NearestNeighbors(n_neighbors=1).fit(cell_centers)
    for cloud in tqdm(pclouds, desc='计算占据网格'):
        _, hit = nbrs.kneighbors(cloud)
        occupancy[np.unique(hit.squeeze())] += 1
    return occupancy
def jensen_shannon_divergence(P, Q):
    """Jensen-Shannon divergence (base 2) between two histograms.

    Inputs need not be normalized; a tiny epsilon guards against division by
    zero for an all-zero histogram. Result lies in [0, 1] for base 2.
    """
    p = P / (P.sum() + 1e-9)
    q = Q / (Q.sum() + 1e-9)
    mid = 0.5 * (p + q)
    kl_p = entropy(p, mid, base=2)
    kl_q = entropy(q, mid, base=2)
    return 0.5 * (kl_p + kl_q)
def compute_jsd(sample_pcs, ref_pcs, resolution):
    """JSD between the occupancy-grid distributions of the two cloud sets."""
    print("\n--- 开始计算 JSD ---")
    occ_sample = entropy_of_occupancy_grid(sample_pcs, resolution)
    occ_ref = entropy_of_occupancy_grid(ref_pcs, resolution)
    return jensen_shannon_divergence(occ_sample, occ_ref)
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == '__main__':
    # 1. Load the meshes and sample them into point clouds (NumPy arrays).
    sample_pcs_np = process_meshes_in_folder(GENERATED_MESH_DIR, NUM_POINTS_PER_MESH)
    ref_pcs_np = process_meshes_in_folder(GT_MESH_DIR, NUM_POINTS_PER_MESH)
    print(f"\n加载完成: {sample_pcs_np.shape[0]} 个生成点云, {ref_pcs_np.shape[0]} 个真实点云。")
    print(f"每个点云包含 {sample_pcs_np.shape[1]} 个点。")
    # 2. Pick the device and convert the data to PyTorch tensors.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    sample_pcs_torch = torch.from_numpy(sample_pcs_np).float().to(device)
    ref_pcs_torch = torch.from_numpy(ref_pcs_np).float().to(device)
    # 3. Distribution metrics: MMD, COV, 1-NNA (PyTorch).
    metrics_results = compute_mmd_cov_1nna(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
    # 4. Pairwise geometric metrics: CD, HD (PyTorch).
    cd_hd_results = compute_cd_hd(sample_pcs_torch, ref_pcs_torch, BATCH_SIZE)
    metrics_results.update(cd_hd_results) # merge the result dicts
    # 5. JSD (NumPy).
    jsd_result = compute_jsd(sample_pcs_np, ref_pcs_np, JSD_RESOLUTION)
    # 6. Print the final report.
    print("\n==================================================")
    print(" 评估结果")
    print("==================================================")
    print("\n--- 分布质量与多样性 (Distribution Metrics) ---")
    # MMD: lower is better (quality).
    print(f"{'lgan_mmd-CD':<12s}: {metrics_results['lgan_mmd-CD'].item():.6f} (↓ Lower is better)")
    # COV: higher is better (diversity).
    print(f"{'lgan_cov-CD':<12s}: {metrics_results['lgan_cov-CD'].item():.6f} (↑ Higher is better)")
    # 1-NNA: closer to 0.5 is better (realism).
    print(f"{'1-NNA-CD':<12s}: {metrics_results['1-NNA-CD'].item():.6f} (→ Closer to 0.5 is better)")
    # JSD: lower is better (distribution similarity).
    print(f"{'JSD':<12s}: {jsd_result:.6f} (↓ Lower is better)")
    print("\n--- 平均几何保真度 (Average Geometric Fidelity) ---")
    # CD: lower is better.
    print(f"{'Chamfer-L2':<12s}: {metrics_results['Chamfer-L2'].item():.6f} (↓ Lower is better)")
    # HD: lower is better.
    print(f"{'Hausdorff':<12s}: {metrics_results['Hausdorff'].item():.6f} (↓ Lower is better)")
    print("==================================================")