# ndjadjafbagk / metric_cd.py
# Uploaded by udbbdh with huggingface_hub (commit 7340df2).
import os
import argparse
import numpy as np
import torch
import trimesh
from tqdm import tqdm
# =====================================================
# 🔹 Mesh归一化函数
# =====================================================
def normalize_to_unit_sphere(mesh: "trimesh.Trimesh") -> "trimesh.Trimesh":
    """Center the mesh on its vertex centroid and scale it into the unit sphere.

    The centroid is the unweighted mean of ``mesh.vertices``. After centering,
    all vertices are divided by the largest vertex distance from the origin,
    so the farthest vertex lands on the unit sphere.

    The mesh is modified in place (``mesh.vertices`` is reassigned) and also
    returned. Degenerate input is handled without producing NaN:
    a mesh with no vertices is returned unchanged, and a mesh whose vertices
    all coincide is only translated (no division by a zero radius).
    """
    vertices = np.asarray(mesh.vertices)
    if len(vertices) == 0:
        # mean/max of an empty array would be NaN / raise; nothing to normalize.
        return mesh
    centered = vertices - vertices.mean(axis=0)
    radius = np.linalg.norm(centered, axis=1).max()
    if radius > 0:
        # Skip the division for a zero radius (all vertices coincident) to avoid NaN.
        centered = centered / radius
    mesh.vertices = centered
    return mesh
def normalize_to_unit_cube(mesh: "trimesh.Trimesh") -> "trimesh.Trimesh":
    """Center the mesh on its bounding-box center and scale it into [-1, 1]^3.

    The bbox comes from ``mesh.bounds``. Vertices are shifted so the bbox
    center is at the origin, then divided by half of the longest bbox side.
    The longest axis therefore spans exactly [-1, 1], and the others fit inside.

    The mesh is modified in place (``mesh.vertices`` is reassigned) and also
    returned. If ``mesh.bounds`` is None, the mesh is returned unchanged. A
    zero-extent bbox (all vertices coincident) is only translated, which
    avoids dividing by zero.
    """
    bounds = mesh.bounds
    if bounds is None:
        # No bbox available (e.g. no referenced vertices); nothing to normalize.
        return mesh
    bbox_min, bbox_max = bounds
    center = (bbox_min + bbox_max) / 2
    half_extent = (bbox_max - bbox_min).max() / 2
    vertices = mesh.vertices - center
    if half_extent > 0:
        # Skip the division for a degenerate bbox to avoid NaN/inf vertices.
        vertices = vertices / half_extent
    mesh.vertices = vertices
    return mesh
# =====================================================
# 🔹 点云采样函数 + 返回面片数
# =====================================================
def sample_points_from_mesh(mesh_path: str, num_points: int, normalize: str = "none"):
    """Load a mesh file, optionally normalize it, and sample points on its surface.

    Args:
        mesh_path: file handed to ``trimesh.load`` (force='mesh', process=False).
        num_points: sample count passed to ``trimesh.sample.sample_surface``.
        normalize: "sphere" applies normalize_to_unit_sphere, "cube" applies
            normalize_to_unit_cube, and any other value leaves the mesh as loaded.

    Returns:
        ``(points, face_count)``: the samples as a float32 torch tensor and
        ``len(mesh.faces)``. Returns ``(None, 0)`` if any step raised; a warning
        is printed in that case and the exception is not propagated.
    """
    # Checked in order; the first mode equal to `normalize` wins.
    normalizers = (
        ("sphere", normalize_to_unit_sphere),
        ("cube", normalize_to_unit_cube),
    )
    try:
        loaded = trimesh.load(mesh_path, force='mesh', process=False)
        for mode, normalizer in normalizers:
            if normalize == mode:
                loaded = normalizer(loaded)
                break
        samples, _face_index = trimesh.sample.sample_surface(loaded, num_points)
        n_faces = len(loaded.faces)
        return torch.from_numpy(samples).float(), n_faces
    except Exception as err:
        # Best-effort: a single unreadable mesh should not abort the whole evaluation.
        print(f"[-] 警告:加载或采样文件失败 {mesh_path}。错误: {err}")
        return None, 0
# =====================================================
# 🔹 Chamfer Distance 计算函数
# =====================================================
def find_minimum_cd_batched(gen_pc: torch.Tensor, gt_pcs_batch: torch.Tensor):
    """Find the GT cloud in a batch that is closest to ``gen_pc`` by Chamfer distance.

    ``gen_pc`` is a 2-D point tensor (N, D). ``gt_pcs_batch`` is a stack of B
    GT clouds, (B, M, D). The Chamfer distance used here has two terms, which
    are summed:
      - the mean nearest-neighbour Euclidean (not squared) distance from the
        generated points to the GT points;
      - the same quantity from the GT points to the generated points.

    Returns:
        ``(min_cd, index)`` as Python ``float`` / ``int``. ``index`` is the
        position within this batch; ties resolve to the first minimum.
    """
    num_gt = gt_pcs_batch.shape[0]
    # Broadcast the single generated cloud against every GT cloud in the batch.
    tiled = gen_pc[None].expand(num_gt, -1, -1)
    pairwise = torch.cdist(tiled, gt_pcs_batch)  # (B, N, M)
    forward = pairwise.amin(dim=2).mean(dim=1)   # generated -> GT
    backward = pairwise.amin(dim=1).mean(dim=1)  # GT -> generated
    best_cd, best_idx = torch.min(forward + backward, dim=0)
    return best_cd.item(), best_idx.item()
# =====================================================
# 🔹 主流程
# =====================================================
# Mesh file extensions picked up from the GT and generated directories.
_MESH_EXTENSIONS = ('.obj', '.ply', '.off')


def _list_mesh_files(directory):
    """Return the sorted names of entries in ``directory`` that end in a mesh extension."""
    return sorted(f for f in os.listdir(directory) if f.endswith(_MESH_EXTENSIONS))


def _load_gt_point_clouds(gt_dir, gt_files, num_points, normalize):
    """Sample every GT mesh and return ``(point_tensors, face_counts)`` for the ones that loaded.

    Files that fail are skipped (sample_points_from_mesh has already printed a
    warning), so the two returned lists stay index-aligned.
    """
    clouds, face_counts = [], []
    for gt_filename in tqdm(gt_files, desc="预处理GT网格"):
        gt_filepath = os.path.join(gt_dir, gt_filename)
        pc, fnum = sample_points_from_mesh(gt_filepath, num_points, normalize)
        if pc is not None:
            clouds.append(pc)
            face_counts.append(fnum)
    return clouds, face_counts


def _match_to_gt(gen_pc, gt_all, batch_size):
    """Return ``(min_cd, gt_index)`` for the GT cloud nearest to ``gen_pc`` by Chamfer distance.

    ``gt_all`` is the stacked GT tensor. It is scanned in slices of ``batch_size``
    to bound the size of each cdist matrix. ``gt_index`` is -1 when no batch
    produced a CD below +inf (e.g. every result was NaN).
    """
    best_cd, best_idx = float('inf'), -1
    for start in range(0, gt_all.size(0), batch_size):
        cd, idx = find_minimum_cd_batched(gen_pc, gt_all[start:start + batch_size])
        if cd < best_cd:  # NaN never compares smaller, so NaN batches are ignored
            best_cd, best_idx = cd, start + idx
    return best_cd, best_idx


def _print_summary(cd_scores, face_ratios, pred_faces, gt_faces):
    """Print the averaged metrics, or a notice when no generated mesh was evaluated."""
    if not cd_scores:
        print("\n[-] 评估结束,但没有成功处理任何网格。")
        return
    mean_min_cd = np.mean(cd_scores)
    mean_face_ratio = np.mean(face_ratios) if face_ratios else 0
    mean_pred_faces = np.mean(pred_faces) if pred_faces else 0
    mean_gt_faces = np.mean(gt_faces) if gt_faces else 0
    print("\n" + "="*70)
    print(f"[*] 评估完成 (基于最小CD匹配)")
    print(f"[*] 共评估 {len(cd_scores)} 个生成网格")
    print(f"[*] 平均最小倒角距离 (Mean Min CD): {mean_min_cd:.6f}")
    print(f"[*] 平均Pred面片数: {mean_pred_faces:.1f}")
    print(f"[*] 平均GT面片数: {mean_gt_faces:.1f}")
    print(f"[*] 平均面片比 (Pred/GT): {mean_face_ratio:.3f}")
    print("="*70)


def main(args):
    """Match each generated mesh to its closest GT mesh by Chamfer distance and report averages.

    Reads from ``args``: gt_dir, generated_dir, num_points, normalize,
    batch_size, quiet. Every mesh is sampled to ``args.num_points`` surface
    points, after optional normalization. Each generated cloud is compared
    against all GT clouds. The minimum CD is recorded, along with the
    generated and matched-GT face counts and their ratio.

    Problems (bad batch size, empty directories, no loadable GT) are reported
    with print() and the function returns early instead of raising.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[*] 使用设备: {device}")
    print(f"[*] 归一化模式: {args.normalize}")
    if args.batch_size <= 0:
        # range() would raise for 0 and silently iterate nothing for negatives.
        print(f"[-] 错误: batch_size 必须为正整数, 当前为 {args.batch_size}")
        return

    # --- Step 1: load and sample all GT meshes ---
    print("[*] 正在预加载并采样所有GT网格...")
    gt_files = _list_mesh_files(args.gt_dir)
    if not gt_files:
        print(f"[-] 错误: GT目录中未找到mesh文件: {args.gt_dir}")
        return
    gt_point_clouds, gt_faces_counts = _load_gt_point_clouds(
        args.gt_dir, gt_files, args.num_points, args.normalize)
    if not gt_point_clouds:
        print("[-] 错误: 无法从任何GT文件采样点云。")
        return
    print(f"[*] 成功加载 {len(gt_point_clouds)} 个GT点云。")
    # Stack once (on CPU, then move) instead of re-stacking every batch for every
    # generated mesh; the slices taken in _match_to_gt are views, not copies.
    gt_all = torch.stack(gt_point_clouds).to(device)
    del gt_point_clouds

    # --- Step 2: evaluate every generated mesh ---
    gen_files = _list_mesh_files(args.generated_dir)
    if not gen_files:
        print(f"[-] 错误: 生成目录中未找到mesh文件: {args.generated_dir}")
        return
    all_min_cd_scores, face_ratios, pred_faces_all, gt_faces_matched = [], [], [], []
    for gen_filename in tqdm(gen_files, desc="评估生成的网格"):
        gen_filepath = os.path.join(args.generated_dir, gen_filename)
        gen_pc, gen_face_count = sample_points_from_mesh(gen_filepath, args.num_points, args.normalize)
        if gen_pc is None:
            continue
        min_cd_for_this_gen, matched_gt_idx = _match_to_gt(gen_pc.to(device), gt_all, args.batch_size)
        if matched_gt_idx < 0:
            # Recording +inf here would turn the reported mean CD into inf.
            print(f"[-] 警告: {gen_filename} 未得到有效的CD (NaN/Inf), 已跳过。")
            continue
        gt_face_count = gt_faces_counts[matched_gt_idx]
        # Guard against division by zero; a face-less GT mesh contributes a ratio of 0.
        face_ratio = gen_face_count / gt_face_count if gt_face_count > 0 else 0
        all_min_cd_scores.append(min_cd_for_this_gen)
        face_ratios.append(face_ratio)
        pred_faces_all.append(gen_face_count)
        gt_faces_matched.append(gt_face_count)
        if not args.quiet:
            print(f" -> {gen_filename}: 最小CD={min_cd_for_this_gen:.6f}, Pred面数={gen_face_count}, GT面数={gt_face_count}, 比值={face_ratio:.3f}")

    # --- Step 3: summary ---
    _print_summary(all_min_cd_scores, face_ratios, pred_faces_all, gt_faces_matched)
# =====================================================
# 🔹 命令行接口
# =====================================================
def positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers greater than zero.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer literal or is <= 0.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"需要正整数, 实际为 {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"需要正整数, 实际为 {number}")
    return number


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="评估生成mesh与GT集合的最小Chamfer Distance及面片数比")
    parser.add_argument("--generated_dir", type=str, required=True, help="生成的mesh文件夹路径")
    parser.add_argument("--gt_dir", type=str, required=True, help="GT网格文件夹路径")
    # Reject non-positive values up front. A non-positive sample count cannot
    # produce a usable point cloud. A negative batch_size makes range() empty,
    # so every CD silently stays +inf; 0 makes range() raise.
    parser.add_argument("--num_points", type=positive_int, default=10000, help="每个mesh采样点数")
    parser.add_argument("--batch_size", type=positive_int, default=16, help="与多少个GT点云进行批处理比较")
    parser.add_argument("--normalize", type=str, default="none", choices=["none", "sphere", "cube"], help="归一化模式: none | sphere | cube")
    parser.add_argument("--quiet", action="store_true", help="静默模式,不打印逐文件结果,只输出最终汇总")
    args = parser.parse_args()
    main(args)

# Example invocations (kept as a no-op string literal).
'''
# 不归一化
python metric_cd.py \
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
--num_points 4096 \
--normalize none
# 归一化到单位球
python metric_cd.py \
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs/0.8_1.5 \
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
--num_points 4096 \
--normalize sphere
# 归一化到单位立方体
python metric_cd.py \
--generated_dir /root/Trisf/experiments_edge/train_set/1e-2kl_base/epoch_20_test_set_obj_0gs \
--gt_dir /root/Trisf/abalation_post_processing/gt_mesh \
--num_points 4096 \
--normalize cube
'''