# FreeArt3D / evaluate_partnet.py
# Uploaded by MorPhLingXD ("Upload demo", commit bf912fe)
import argparse
import tqdm
import os
import numpy as np
import os.path as osp
import json
import random
import torch
import trimesh
import mathutils
from eval_utils.neucon_eval_utils import eval_mesh
from eval_utils.utils_3d import rotz_np, transform_points, Rt_to_pose, open3d_icp_api
from eval_utils.render import generate_opengl_view_matrices_on_sphere, transform_bpy_mesh, eval_rendering_bpy
from eval_utils.clip import eval_clip_similarity
from eval_utils.joints import Joint, eval_joint
def auto_align_mesh_with_scale_and_rotz(gt_mesh, pred_mesh, name, dbg=False, ICP_POINTS=2000,
                                        scale_min=0.7, scale_max=1.3, evaltime=False):
    """Align ``pred_mesh`` onto ``gt_mesh`` by grid-searching a z-rotation and a
    uniform scale, refining each candidate with bidirectional ICP, and applying
    the best configuration to ``pred_mesh`` in place.

    Args:
        gt_mesh: target mesh (trimesh.Trimesh), assumed already normalized.
        pred_mesh: mesh to align; MUTATED in place with the winning transform.
        name: identifier for this alignment run (kept for interface
            compatibility; currently unused).
        dbg: debug flag (kept for interface compatibility; currently unused).
        ICP_POINTS: number of surface samples fed to each ICP call.
        scale_min / scale_max: bounds of the uniform-scale search grid.
        evaltime: kept for interface compatibility; currently unused.

    Returns:
        (pred_mesh, scale, final_rotz, final_tsfm) where ``scale`` is the total
        scale applied (normalization * grid scale), ``final_rotz`` the winning
        z-rotation in degrees, and ``final_tsfm`` the 4x4 ICP refinement.
    """
    # FIX: endpoint=False — the original linspace(0, 360, 12) evaluated both
    # 0 and 360 degrees, which are the same rotation, so only 11 distinct
    # angles were actually searched. Now all 12 slots are distinct.
    rot_angles = np.linspace(0, 360, 12, endpoint=False)
    scales = np.linspace(scale_min, scale_max, 10)
    # Loop invariants hoisted out of the 120-configuration search:
    # the GT bounding-box extent, and one fixed random GT sample set
    # (was redrawn per configuration; a fixed target set also makes the
    # fitness comparison across candidates consistent).
    gt_extent = gt_mesh.bounding_box.extents.max()
    gt_mesh_pts = gt_mesh.sample(ICP_POINTS)
    cfgs = []
    results, results2 = [], []
    for rot_angle in tqdm.tqdm(rot_angles, disable=True):
        for scale in tqdm.tqdm(scales, leave=False, disable=True):
            cfgs.append((rot_angle, scale))
            rotz = Rt_to_pose(rotz_np(np.deg2rad(rot_angle))[0])
            # Work on a copy so the candidate transform never touches pred_mesh.
            tmp_mesh = trimesh.Trimesh(pred_mesh.vertices, pred_mesh.faces)
            tmp_mesh.vertices = transform_points(tmp_mesh.vertices, rotz)
            # Match the GT bounding-box size first, then apply the grid scale.
            tmp_mesh.apply_scale(gt_extent / tmp_mesh.bounding_box.extents.max())
            tmp_mesh.vertices *= scale
            tmp_mesh_pts = tmp_mesh.sample(ICP_POINTS)
            # Coarse-to-fine ICP, pred -> gt (0.08 then 0.05 of the GT extent).
            result = open3d_icp_api(tmp_mesh_pts, gt_mesh_pts, thresh=0.08 * gt_extent,
                                    return_tsfm_only=False)
            result = open3d_icp_api(tmp_mesh_pts, gt_mesh_pts, thresh=0.05 * gt_extent,
                                    init_Rt=result.transformation,
                                    return_tsfm_only=False)
            # Same coarse-to-fine ICP in the reverse direction, gt -> pred;
            # only its fitness is used (to score the candidate symmetrically).
            result2 = open3d_icp_api(gt_mesh_pts, tmp_mesh_pts, thresh=0.08 * gt_extent,
                                     return_tsfm_only=False)
            result2 = open3d_icp_api(gt_mesh_pts, tmp_mesh_pts, thresh=0.05 * gt_extent,
                                     init_Rt=result2.transformation, return_tsfm_only=False)
            results.append(result)
            results2.append(result2)
    # Pick the configuration with the best combined bidirectional fitness.
    fitnesses = np.array([r.fitness for r in results])
    fitnesses2 = np.array([r.fitness for r in results2])
    best_idx = np.argmax(fitnesses + fitnesses2)
    final_tsfm = results[best_idx].transformation
    final_rotz, final_scale = cfgs[best_idx]
    # Re-apply the winning rotation / scale / ICP transform to pred_mesh itself.
    rotz = Rt_to_pose(rotz_np(np.deg2rad(final_rotz))[0])
    pred_mesh.vertices = transform_points(pred_mesh.vertices, rotz)
    scale1 = gt_extent / pred_mesh.bounding_box.extents.max()
    pred_mesh.apply_scale(scale1)
    pred_mesh.vertices *= final_scale
    # Total scale = bbox normalization * grid scale (returned for caching).
    scale = final_scale * scale1
    pred_mesh.vertices = transform_points(pred_mesh.vertices, final_tsfm)
    return pred_mesh, scale, final_rotz, final_tsfm
if __name__ == '__main__':
    # CLI: evaluate one PartNet object (``--test_id``) produced by FreeArt3D,
    # comparing predicted articulated states against GT meshes and joints.
    args = argparse.ArgumentParser()
    args.add_argument('--dbg', action='store_true', help='Debug mode')
    args.add_argument('--dataset_folder', type=str, default='datasets/PartNet')
    args.add_argument('--pred_folder', type=str, default='outputs')
    args.add_argument('--res_folder', type=str, default='evaluations/PartNet')
    args.add_argument('--test_id', type=str, default='100214')
    args.add_argument('--num_states', type=int, default=6, help='Number of states')
    args.add_argument('--num_cams', type=int, default=5, help='Number of cameras')
    args = args.parse_args()
    dataset_folder = args.dataset_folder
    pred_folder = args.pred_folder
    test_id = args.test_id
    methods = ['ours']
    # Geometric metric names as returned by eval_mesh ('recal' spelling matches
    # the eval_utils key — do not "fix" it here).
    metrics = ['fscore', 'dist1', 'dist2', 'prec', 'recal']
    TOTAL_QPOS_NUM = args.num_states     # number of articulation states evaluated
    EACH_QPOS_CAM_NUM = args.num_cams    # rendered views per state
    os.makedirs(args.res_folder, exist_ok=True)
    # Fix all RNG seeds so mesh sampling / alignment are reproducible.
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)
    has_aligned = False  # set True below when a cached alignment JSON exists
    # Sample cameras for rendering
    camera_distance = 1.7
    camera_poses = generate_opengl_view_matrices_on_sphere(TOTAL_QPOS_NUM * EACH_QPOS_CAM_NUM, camera_distance)
    camera_sources = []
    for i, camera_pose in enumerate(camera_poses):
        # Camera position = translation column of the inverted view matrix.
        inv_view = np.linalg.inv(camera_pose)
        camera_pos = inv_view[:3, 3]
        # Pin the first camera of every state to a fixed canonical viewpoint
        # (15 deg azimuth, 20 deg elevation); the rest keep their sphere samples.
        if i % EACH_QPOS_CAM_NUM == 0:
            azi = np.deg2rad(15)
            height = np.sin(np.deg2rad(20)) * camera_distance
            x_pos = np.sin(azi) * camera_distance
            y_pos = -np.cos(azi) * camera_distance
            camera_pos = [x_pos, y_pos, height]
        camera_sources.append(mathutils.Vector((camera_pos[0], camera_pos[1], camera_pos[2])))
    res = {}
    res_path = osp.join(args.res_folder, f'{test_id}.json')
    align_path = f'{args.res_folder}/aligns/{test_id}.json'
    render_path = f'{args.res_folder}/renderings/{test_id}'
    os.makedirs(osp.dirname(align_path), exist_ok=True)
    os.makedirs(osp.dirname(render_path), exist_ok=True)
    # Reuse cached alignments when available so repeated runs skip the ICP search.
    if os.path.exists(align_path):
        with open(align_path, 'r') as f:
            align_json = json.load(f)
        has_aligned = True
    else:
        align_json = {}
    if not has_aligned:
        align_json['gt'] = []
    for method in methods:
        print(f"Evaluating {test_id} with {method}")
        item_res = {}
        if not has_aligned:
            align_json[method] = []
        # Load joint info
        gt_joint_info_path = osp.join(dataset_folder, test_id, 'joint_info.json')
        with open(gt_joint_info_path, 'r') as f:
            gt_joint_info = json.load(f)[0]  # only the first joint is evaluated
        gt_joint = Joint(gt_joint_info, method='gt')
        if method == 'ours':
            pred_joint_path = osp.join(pred_folder, test_id, 'sds_output', 'joint_info.json')
        else:
            pred_joint_path = osp.join(pred_folder, test_id, method, 'joint_info.json')
        with open(pred_joint_path, 'r') as f:
            pred_joint_info = json.load(f)[0]
        pred_joint = Joint(pred_joint_info, method=method)
        for qpos_id in tqdm.tqdm(range(TOTAL_QPOS_NUM)):
            # Load GT mesh & joint info (state files are stored in reverse
            # order on disk, hence TOTAL_QPOS_NUM - 1 - qpos_id).
            gt_mesh_path = osp.join(dataset_folder, test_id, 'gt_mesh', f'{TOTAL_QPOS_NUM - 1 - qpos_id:02d}.glb')
            gt_mesh: trimesh.Trimesh = trimesh.load(gt_mesh_path, force='mesh')
            # Normalize GT to a unit-extent bounding box centered at the origin;
            # keep the scale/centroid so joints and renders can be aligned too.
            gt_mesh_extents = 1 / gt_mesh.extents.max()
            gt_mesh.apply_scale(1 / gt_mesh.extents.max())
            gt_mesh_centroid = gt_mesh.centroid
            gt_mesh.apply_translation(-gt_mesh.centroid)
            transform_bpy_mesh(gt_mesh_path, f'{render_path}/gt/gt_mesh_aligned_{qpos_id:02d}.glb',
                               np.array(gt_mesh_extents), np.array(gt_mesh_centroid), gt=True)
            if qpos_id == 0:  # Use the qpos=1 state to align the joints
                # print(f"[GT Joint] Before alignment: {gt_joint.axis_orig}, {gt_joint.axis_dir}")
                gt_joint.apply_scale(gt_mesh_extents)
                gt_joint.apply_translation(-gt_mesh_centroid)
                # print(f"[GT Joint] After alignment: {gt_joint.axis_orig}, {gt_joint.axis_dir}")
            # Render GT mesh (skip when the last camera image already exists)
            if not os.path.exists(f'{render_path}/gt/qpos_{qpos_id:02d}/cam_04.png'):
                gt_images = eval_rendering_bpy(camera_sources[qpos_id*EACH_QPOS_CAM_NUM:(qpos_id+1)*EACH_QPOS_CAM_NUM],
                                               render_path, 'gt', qpos_id, f'{render_path}/gt/gt_mesh_aligned_{qpos_id:02d}.glb')
                gt_rendered_flag = True
            if not has_aligned and len(align_json['gt']) < TOTAL_QPOS_NUM:
                align_json['gt'].append({'scale1': gt_mesh_extents.tolist(),
                                         'translation1': gt_mesh_centroid.tolist()})
            if method == 'ours':  # Default to fetch from the output folder
                pred_mesh_path = osp.join(pred_folder, test_id,
                                          'sds_output', 'states', f'qpos_{TOTAL_QPOS_NUM - 1 - qpos_id:02d}.glb')
            else:
                pred_mesh_path = osp.join(pred_folder, test_id, method,
                                          'sds_output', 'states', f'qpos_{TOTAL_QPOS_NUM - 1 - qpos_id:02d}.glb')
            if not osp.exists(pred_mesh_path):
                print(f"Prediction mesh {pred_mesh_path} does not exist")
                continue
            pred_mesh: trimesh.Trimesh = trimesh.load(pred_mesh_path, force='mesh')
            # Normalize the prediction the same way as the GT mesh.
            pred_mesh_extents = 1 / pred_mesh.extents.max()
            pred_mesh.apply_scale(1 / pred_mesh.extents.max())
            pred_mesh_centroid = pred_mesh.centroid
            pred_mesh.apply_translation(-pred_mesh.centroid)
            if not has_aligned:
                # auto_align_mesh_with_scale_and_rotz mutates pred_mesh in
                # place (rotation, scale and ICP transform already applied)
                # and returns the fitted parameters for caching/rendering.
                pred_mesh, scale, final_rotz, final_tsfm = auto_align_mesh_with_scale_and_rotz(gt_mesh, pred_mesh, name=f"mesh_ours_{qpos_id:02d}",
                                                                                               scale_min=0.5, scale_max=1.5)
                rotz = Rt_to_pose(rotz_np(np.deg2rad(final_rotz))[0])
            else:
                # Replay the cached alignment (values come back as JSON lists).
                scale = align_json[method][qpos_id]['scale2']
                rotz = align_json[method][qpos_id]['rotz']
                final_tsfm = align_json[method][qpos_id]['final_tsfm']
                pred_mesh.apply_transform(rotz)
                pred_mesh.apply_scale(scale)
                pred_mesh.apply_transform(final_tsfm)
            if not has_aligned:
                align_json[method].append({'scale1': pred_mesh_extents.tolist(),
                                           'translation1': pred_mesh_centroid.tolist(),
                                           'scale2': scale.tolist(),
                                           'rotz': rotz.tolist(),
                                           'final_tsfm': final_tsfm.tolist()})
            # Render pred mesh
            transform_bpy_mesh(pred_mesh_path, f'{render_path}/{method}/{method}_mesh_aligned_{qpos_id:02d}.glb',
                               np.array(pred_mesh_extents), np.array(pred_mesh_centroid), rotz=np.array(rotz),
                               final_tsfm=np.array(final_tsfm), scale=np.array(scale), gt=False)
            pred_images = eval_rendering_bpy(camera_sources[qpos_id*EACH_QPOS_CAM_NUM:(qpos_id+1)*EACH_QPOS_CAM_NUM],
                                             render_path, method, qpos_id, f'{render_path}/{method}/{method}_mesh_aligned_{qpos_id:02d}.glb')
            if qpos_id == 0:
                # Carry the predicted joint through the same normalization and
                # alignment chain applied to the predicted mesh.
                # print(f"[Pred Joint] Before alignment: {pred_joint.axis_orig}, {pred_joint.axis_dir}")
                pred_joint.apply_scale(pred_mesh_extents)
                pred_joint.apply_translation(-pred_mesh_centroid)
                pred_joint.apply_transform(np.array(rotz))
                pred_joint.apply_scale(np.array(scale))
                pred_joint.apply_transform(np.array(final_tsfm))
                # print(f"[Pred Joint] After alignment: {pred_joint.axis_orig}, {pred_joint.axis_dir}")
            # Evaluate geometric metrics
            results = eval_mesh(pred_mesh, gt_mesh, threshold=.05 * 1.0, down_sample=None)
            for metric in metrics:
                if metric not in item_res:
                    item_res[metric] = []
                item_res[metric].append(results[metric].item())
        # Compute mean geometric metrics (NaN when every state was skipped)
        fail_flag = True if len(item_res) == 0 else False
        for metric in metrics:
            item_res[metric] = np.mean(item_res[metric]) if not fail_flag else np.nan
        # Evaluate clip similarity
        mean_clip_sim = eval_clip_similarity(render_path, method, TOTAL_QPOS_NUM, EACH_QPOS_CAM_NUM)
        item_res['clip_sim'] = mean_clip_sim
        # Evaluate joint metrics
        joint_res = eval_joint(pred_joint, gt_joint)
        item_res['joint_axis_err'] = joint_res['joint_axis_err']
        item_res['joint_orig_err'] = joint_res['joint_orig_err']
        # Update results
        res[test_id] = {method: item_res} if test_id not in res else {**res[test_id], **{method: item_res}}
    # Save alignments
    os.makedirs(osp.dirname(align_path), exist_ok=True)
    with open(align_path, 'w') as f:
        json.dump(align_json, f, indent=4)
    # Save results as json
    with open(res_path, "w") as f:
        json.dump(res, f, indent=4)