Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import tqdm | |
| import os | |
| import numpy as np | |
| import os.path as osp | |
| import json | |
| import random | |
| import torch | |
| import trimesh | |
| import mathutils | |
| from eval_utils.neucon_eval_utils import eval_mesh | |
| from eval_utils.utils_3d import rotz_np, transform_points, Rt_to_pose, open3d_icp_api | |
| from eval_utils.render import generate_opengl_view_matrices_on_sphere, transform_bpy_mesh, eval_rendering_bpy | |
| from eval_utils.clip import eval_clip_similarity | |
| from eval_utils.joints import Joint, eval_joint | |
def auto_align_mesh_with_scale_and_rotz(gt_mesh, pred_mesh, name, dbg=False, ICP_POINTS=2000,
                                        scale_min=0.7, scale_max=1.3, evaltime=False):
    """Align ``pred_mesh`` to ``gt_mesh`` with a z-rotation/scale grid search refined by ICP.

    For every (rotation, scale) candidate a temporary copy of the prediction is
    rotated about z, rescaled so that its largest bounding-box extent matches
    the ground truth, multiplied by the candidate scale, and registered against
    the ground truth with a two-stage ICP (coarse threshold, then a finer one
    seeded with the coarse result). ICP is run in both directions and the
    candidate with the highest summed fitness is selected. The winning
    rotation, scale and pred->gt ICP transform are then applied to
    ``pred_mesh`` itself.

    Args:
        gt_mesh: Reference mesh. Only its bounding box and surface samples are read.
        pred_mesh: Mesh to align. Its vertices are modified IN PLACE.
        name: Unused; kept so existing callers keep working.
        dbg: Unused; kept so existing callers keep working.
        ICP_POINTS: Number of surface points sampled from each mesh per candidate.
        scale_min: Smallest extra scale factor tried (on top of extent matching).
        scale_max: Largest extra scale factor tried.
        evaltime: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(pred_mesh, scale, final_rotz, final_tsfm)``:
        the same (now aligned) ``pred_mesh`` object, the total scale applied
        after the rotation (extent-matching factor times the selected grid
        scale), the selected rotation about z in degrees, and the
        ``transformation`` of the selected pred->gt ICP result.
    """
    # NOTE(review): with 12 samples over [0, 360] both 0 and 360 degrees are
    # evaluated, i.e. the same rotation twice. Left untouched so that existing
    # evaluation numbers stay reproducible; consider endpoint=False.
    rot_angles = np.linspace(0, 360, 12)
    scales = np.linspace(scale_min, scale_max, 10)

    # gt_mesh is never modified here, so its extent and both ICP distance
    # thresholds are loop invariants.
    gt_extent = gt_mesh.bounding_box.extents.max()
    coarse_thresh = 0.08 * gt_extent
    fine_thresh = 0.05 * gt_extent

    cfgs = []
    results, results2 = [], []
    for rot_angle in rot_angles:
        for scale in scales:
            cfgs.append((rot_angle, scale))
            rotz = Rt_to_pose(rotz_np(np.deg2rad(rot_angle))[0])

            # Work on a fresh copy so pred_mesh stays untouched during the search.
            tmp_mesh = trimesh.Trimesh(pred_mesh.vertices, pred_mesh.faces)
            tmp_mesh.vertices = transform_points(tmp_mesh.vertices, rotz)
            tmp_mesh.apply_scale(gt_extent / tmp_mesh.bounding_box.extents.max())
            tmp_mesh.vertices *= scale

            # Fresh surface samples are drawn for every candidate (prediction
            # first, then ground truth, as before).
            tmp_mesh_pts = tmp_mesh.sample(ICP_POINTS)
            gt_mesh_pts = gt_mesh.sample(ICP_POINTS)

            # Prediction -> ground truth: coarse pass, then fine pass seeded by it.
            result = open3d_icp_api(tmp_mesh_pts, gt_mesh_pts, thresh=coarse_thresh,
                                    return_tsfm_only=False)
            result = open3d_icp_api(tmp_mesh_pts, gt_mesh_pts, thresh=fine_thresh,
                                    init_Rt=result.transformation,
                                    return_tsfm_only=False)

            # Ground truth -> prediction: only its fitness is used, for scoring.
            result2 = open3d_icp_api(gt_mesh_pts, tmp_mesh_pts, thresh=coarse_thresh,
                                     return_tsfm_only=False)
            result2 = open3d_icp_api(gt_mesh_pts, tmp_mesh_pts, thresh=fine_thresh,
                                     init_Rt=result2.transformation,
                                     return_tsfm_only=False)

            results.append(result)
            results2.append(result2)

    # Score each candidate by the sum of the fitness in both directions.
    fitnesses = [r.fitness for r in results]
    fitnesses2 = [r.fitness for r in results2]
    fitnesses_all = np.array(fitnesses) + np.array(fitnesses2)
    best_idx = np.argmax(fitnesses_all)

    final_tsfm = results[best_idx].transformation
    final_rotz, final_scale = cfgs[best_idx]

    # Replay the winning candidate on the real prediction mesh (in place).
    rotz = Rt_to_pose(rotz_np(np.deg2rad(final_rotz))[0])
    pred_mesh.vertices = transform_points(pred_mesh.vertices, rotz)
    scale1 = gt_extent / pred_mesh.bounding_box.extents.max()
    pred_mesh.apply_scale(scale1)
    pred_mesh.vertices *= final_scale
    scale = final_scale * scale1
    pred_mesh.vertices = transform_points(pred_mesh.vertices, final_tsfm)
    return pred_mesh, scale, final_rotz, final_tsfm
if __name__ == '__main__':
    # Evaluate the predicted articulated meshes of one test object against the
    # ground truth: geometric metrics per state, CLIP similarity on renderings,
    # and joint axis/origin errors. Alignments are cached on disk and reused.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dbg', action='store_true', help='Debug mode')
    parser.add_argument('--dataset_folder', type=str, default='datasets/PartNet')
    parser.add_argument('--pred_folder', type=str, default='outputs')
    parser.add_argument('--res_folder', type=str, default='evaluations/PartNet')
    parser.add_argument('--test_id', type=str, default='100214')
    parser.add_argument('--num_states', type=int, default=6, help='Number of states')
    parser.add_argument('--num_cams', type=int, default=5, help='Number of cameras')
    args = parser.parse_args()

    dataset_folder = args.dataset_folder
    pred_folder = args.pred_folder
    test_id = args.test_id
    methods = ['ours']
    metrics = ['fscore', 'dist1', 'dist2', 'prec', 'recal']
    TOTAL_QPOS_NUM = args.num_states
    EACH_QPOS_CAM_NUM = args.num_cams
    os.makedirs(args.res_folder, exist_ok=True)

    # Fixed seeds so the randomised parts of the pipeline are repeatable.
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)
    has_aligned = False

    # Sample cameras for rendering: EACH_QPOS_CAM_NUM consecutive cameras per state.
    camera_distance = 1.7
    camera_poses = generate_opengl_view_matrices_on_sphere(TOTAL_QPOS_NUM * EACH_QPOS_CAM_NUM, camera_distance)
    camera_sources = []
    for i, camera_pose in enumerate(camera_poses):
        inv_view = np.linalg.inv(camera_pose)
        camera_pos = inv_view[:3, 3]
        if i % EACH_QPOS_CAM_NUM == 0:
            # The first camera of every state is replaced by one fixed viewpoint
            # (15 deg azimuth, 20 deg elevation term) at the same distance.
            azi = np.deg2rad(15)
            height = np.sin(np.deg2rad(20)) * camera_distance
            x_pos = np.sin(azi) * camera_distance
            y_pos = -np.cos(azi) * camera_distance
            camera_pos = [x_pos, y_pos, height]
        camera_sources.append(mathutils.Vector((camera_pos[0], camera_pos[1], camera_pos[2])))

    res = {}
    res_path = osp.join(args.res_folder, f'{test_id}.json')
    align_path = f'{args.res_folder}/aligns/{test_id}.json'
    render_path = f'{args.res_folder}/renderings/{test_id}'
    os.makedirs(osp.dirname(align_path), exist_ok=True)
    os.makedirs(osp.dirname(render_path), exist_ok=True)

    # Reuse a previously saved alignment for this test id when one exists.
    if os.path.exists(align_path):
        with open(align_path, 'r') as f:
            align_json = json.load(f)
        has_aligned = True
    else:
        align_json = {}
    if not has_aligned:
        align_json['gt'] = []

    for method in methods:
        print(f"Evaluating {test_id} with {method}")
        item_res = {}
        if not has_aligned:
            align_json[method] = []

        # Load joint info (first entry of each joint_info.json).
        gt_joint_info_path = osp.join(dataset_folder, test_id, 'joint_info.json')
        with open(gt_joint_info_path, 'r') as f:
            gt_joint_info = json.load(f)[0]
        gt_joint = Joint(gt_joint_info, method='gt')
        if method == 'ours':
            pred_joint_path = osp.join(pred_folder, test_id, 'sds_output', 'joint_info.json')
        else:
            pred_joint_path = osp.join(pred_folder, test_id, method, 'joint_info.json')
        with open(pred_joint_path, 'r') as f:
            pred_joint_info = json.load(f)[0]
        pred_joint = Joint(pred_joint_info, method=method)

        # Alignment entries are appended only for states whose prediction mesh
        # exists, so the cache must be read with a running count of processed
        # states rather than with qpos_id; otherwise one skipped state shifts
        # every later lookup.
        aligned_count = 0
        for qpos_id in tqdm.tqdm(range(TOTAL_QPOS_NUM)):
            # Mesh files are read in reverse order: qpos_id 0 loads the
            # highest-numbered state file.
            file_idx = TOTAL_QPOS_NUM - 1 - qpos_id

            # Load GT mesh, normalise it to unit max extent and centre it.
            # NOTE: despite its name, gt_mesh_extents is the scale FACTOR
            # (1 / max extent), not the extents themselves.
            gt_mesh_path = osp.join(dataset_folder, test_id, 'gt_mesh', f'{file_idx:02d}.glb')
            gt_mesh: trimesh.Trimesh = trimesh.load(gt_mesh_path, force='mesh')
            gt_mesh_extents = 1 / gt_mesh.extents.max()
            gt_mesh.apply_scale(gt_mesh_extents)
            gt_mesh_centroid = gt_mesh.centroid
            gt_mesh.apply_translation(-gt_mesh_centroid)
            gt_aligned_glb = f'{render_path}/gt/gt_mesh_aligned_{qpos_id:02d}.glb'
            transform_bpy_mesh(gt_mesh_path, gt_aligned_glb,
                               np.array(gt_mesh_extents), np.array(gt_mesh_centroid), gt=True)
            if qpos_id == 0:  # Use the qpos=1 state to align the joints
                gt_joint.apply_scale(gt_mesh_extents)
                gt_joint.apply_translation(-gt_mesh_centroid)

            # Render the GT mesh unless the last camera image is already on disk.
            # Assumes images are named cam_{index:02d}.png, matching the
            # previously hard-coded cam_04.png for the default 5 cameras.
            cam_slice = slice(qpos_id * EACH_QPOS_CAM_NUM, (qpos_id + 1) * EACH_QPOS_CAM_NUM)
            last_gt_image = f'{render_path}/gt/qpos_{qpos_id:02d}/cam_{EACH_QPOS_CAM_NUM - 1:02d}.png'
            if not os.path.exists(last_gt_image):
                eval_rendering_bpy(camera_sources[cam_slice], render_path, 'gt', qpos_id, gt_aligned_glb)
            if not has_aligned and len(align_json['gt']) < TOTAL_QPOS_NUM:
                align_json['gt'].append({'scale1': gt_mesh_extents.tolist(),
                                         'translation1': gt_mesh_centroid.tolist()})

            # 'ours' is fetched directly from the output folder; any other
            # method has its own sub-folder.
            if method == 'ours':
                pred_mesh_path = osp.join(pred_folder, test_id,
                                          'sds_output', 'states', f'qpos_{file_idx:02d}.glb')
            else:
                pred_mesh_path = osp.join(pred_folder, test_id, method,
                                          'sds_output', 'states', f'qpos_{file_idx:02d}.glb')
            if not osp.exists(pred_mesh_path):
                print(f"Prediction mesh {pred_mesh_path} does not exist")
                continue

            # Same normalisation as for the GT mesh.
            pred_mesh: trimesh.Trimesh = trimesh.load(pred_mesh_path, force='mesh')
            pred_mesh_extents = 1 / pred_mesh.extents.max()
            pred_mesh.apply_scale(pred_mesh_extents)
            pred_mesh_centroid = pred_mesh.centroid
            pred_mesh.apply_translation(-pred_mesh_centroid)

            if not has_aligned:
                # The aligner transforms pred_mesh in place, so nothing else
                # has to be applied to the mesh in this branch.
                pred_mesh, scale, final_rotz, final_tsfm = auto_align_mesh_with_scale_and_rotz(
                    gt_mesh, pred_mesh, name=f"mesh_ours_{qpos_id:02d}",
                    scale_min=0.5, scale_max=1.5)
                rotz = Rt_to_pose(rotz_np(np.deg2rad(final_rotz))[0])
                align_json[method].append({'scale1': pred_mesh_extents.tolist(),
                                           'translation1': pred_mesh_centroid.tolist(),
                                           'scale2': scale.tolist(),
                                           'rotz': rotz.tolist(),
                                           'final_tsfm': final_tsfm.tolist()})
            else:
                # Replay the cached alignment: rotation, scale, then ICP transform.
                cached = align_json[method][aligned_count]
                scale = cached['scale2']
                rotz = cached['rotz']
                final_tsfm = cached['final_tsfm']
                pred_mesh.apply_transform(rotz)
                pred_mesh.apply_scale(scale)
                pred_mesh.apply_transform(final_tsfm)
            aligned_count += 1

            # Write the aligned prediction mesh and render it from this state's cameras.
            pred_aligned_glb = f'{render_path}/{method}/{method}_mesh_aligned_{qpos_id:02d}.glb'
            transform_bpy_mesh(pred_mesh_path, pred_aligned_glb,
                               np.array(pred_mesh_extents), np.array(pred_mesh_centroid), rotz=np.array(rotz),
                               final_tsfm=np.array(final_tsfm), scale=np.array(scale), gt=False)
            eval_rendering_bpy(camera_sources[cam_slice], render_path, method, qpos_id, pred_aligned_glb)

            if qpos_id == 0:
                # Apply to the predicted joint the same chain the mesh went through.
                pred_joint.apply_scale(pred_mesh_extents)
                pred_joint.apply_translation(-pred_mesh_centroid)
                pred_joint.apply_transform(np.array(rotz))
                pred_joint.apply_scale(np.array(scale))
                pred_joint.apply_transform(np.array(final_tsfm))

            # Evaluate geometric metrics for this state.
            results = eval_mesh(pred_mesh, gt_mesh, threshold=.05 * 1.0, down_sample=None)
            for metric in metrics:
                item_res.setdefault(metric, []).append(results[metric].item())

        # Mean geometric metrics over the evaluated states; NaN when every
        # state was skipped.
        fail_flag = not item_res
        for metric in metrics:
            item_res[metric] = np.nan if fail_flag else np.mean(item_res[metric])

        # Evaluate clip similarity
        mean_clip_sim = eval_clip_similarity(render_path, method, TOTAL_QPOS_NUM, EACH_QPOS_CAM_NUM)
        item_res['clip_sim'] = mean_clip_sim

        # Evaluate joint metrics
        joint_res = eval_joint(pred_joint, gt_joint)
        item_res['joint_axis_err'] = joint_res['joint_axis_err']
        item_res['joint_orig_err'] = joint_res['joint_orig_err']

        # Update results
        res.setdefault(test_id, {})[method] = item_res

    # Save alignments
    os.makedirs(osp.dirname(align_path), exist_ok=True)
    with open(align_path, 'w') as f:
        json.dump(align_json, f, indent=4)

    # Save results as json
    with open(res_path, "w") as f:
        json.dump(res, f, indent=4)