# (extraction artifact — not part of the original script) Spaces:
# (extraction artifact — not part of the original script) Build error
# (extraction artifact — not part of the original script) Build error
| # evaluate test time optimization from refinement | |
| # python src/test_time_optimization/evaluate_ttopt.py --workers 12 --save-images True --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --ttopt-result-name ttoptv6_stanext_v16b | |
| # python src/test_time_optimization/evaluate_ttopt.py --workers 12 --save-images True --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --ttopt-result-name ttoptv6_stanext_v16 | |
| import argparse | |
| import os.path | |
| import json | |
| import numpy as np | |
| import pickle as pkl | |
| from distutils.util import strtobool | |
| import torch | |
| from torch import nn | |
| import torch.backends.cudnn | |
| from torch.nn import DataParallel | |
| from torch.utils.data import DataLoader | |
| import pytorch3d as p3d | |
| from collections import OrderedDict | |
| import glob | |
| from tqdm import tqdm | |
| from dominate import document | |
| from dominate.tags import * | |
| from PIL import Image | |
| from matplotlib import pyplot as plt | |
| import trimesh | |
| import cv2 | |
| import shutil | |
| from pytorch3d.structures import Meshes | |
| from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency | |
| import sys | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) | |
| from combined_model.train_main_image_to_3d_wbr_withref import do_validation_epoch | |
| # from combined_model.model_shape_v7 import ModelImageTo3d_withshape_withproj | |
| # from combined_model.model_shape_v7_withref import ModelImageTo3d_withshape_withproj | |
| from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj | |
| from combined_model.loss_image_to_3d_withbreedrel import Loss | |
| from combined_model.loss_image_to_3d_refinement import LossRef | |
| from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated | |
| from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d # , batch_rot2aa, geodesic_loss_R | |
| # from test_time_optimization.utils_ttopt import get_evaluation_dataset, get_norm_dict | |
| from stacked_hourglass.datasets.utils_dataset_selection import get_evaluation_dataset, get_norm_dict | |
| from test_time_optimization.bite_inference_model_for_ttopt import BITEInferenceModel | |
| from smal_pytorch.smal_model.smal_torch_new import SMAL | |
| from configs.SMAL_configs import SMAL_MODEL_CONFIG | |
| from smal_pytorch.renderer.differentiable_renderer import SilhRenderer | |
| from test_time_optimization.utils.utils_ttopt import reset_loss_values, get_optimed_pose_with_glob | |
| from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error | |
| from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch | |
| from combined_model.loss_utils.loss_arap import Arap_Loss | |
| from combined_model.loss_utils.loss_laplacian_mesh_comparison import LaplacianCTF # (coarse to fine animal) | |
| from graph_networks import graphcmr # .utils_mesh import Mesh | |
| from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image | |
| from metrics.metrics import Metrics | |
| from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS | |
| ROOT_LOSS_WEIGH_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/configs/ttopt_loss_weights/' | |
def main(args):
    """Evaluate saved test-time-optimization (TTOpt) refinement results.

    Loads the per-image meshes (.obj) and focal lengths (.npy) that a previous
    TTOpt run wrote to disk, re-renders silhouettes and 2D keypoints with a
    differentiable renderer, and computes IoU and PCK metrics against the
    ground-truth annotations. Results are printed and written to a .txt file
    inside the result folder.

    Args:
        args: parsed command-line arguments (see the argparse setup at the
            bottom of this file). Uses args.config, args.model_file_complete,
            args.ttopt_result_name and args.workers.
    """
    # load configs
    #   step 1: load default configs
    #   step 2: load updates from .yaml file
    path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args.config)
    update_cfg_global_with_yaml(path_config)
    cfg = get_cfg_global_updated()

    pck_thresh = 0.15  # PCK threshold (fraction of the reference size, as in WLDO)
    print('pck_thresh: ' + str(pck_thresh))

    # input/output folders of the ttopt run that is being evaluated
    ROOT_IN_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results_ttopt/' + args.ttopt_result_name + '/'  # ttoptv6_debug_x8/'
    ROOT_IN_PATH_DETAIL = ROOT_IN_PATH + 'details/'
    ROOT_OUT_PATH = ROOT_IN_PATH + 'evaluation/'
    os.makedirs(ROOT_OUT_PATH, exist_ok=True)  # exist_ok avoids a check-then-create race

    # epoch tag of the saved ttopt results to load; defined up front (not inside
    # the loop) so the summary file name below works even for an empty dataset
    e_name = 'e000'  # 'e300'

    # NEW!!!
    logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
    # logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f']

    # Select the hardware device to use for evaluation.
    if torch.cuda.is_available() and cfg.device == 'cuda':
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = False  # True
    else:
        device = torch.device('cpu')

    print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET)
    print('refinement network type: ' + cfg.params.REF_NET_TYPE)
    print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE)
    path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_complete)

    # prepare dataset and dataset loader
    val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = \
        get_evaluation_dataset(cfg.data.DATASET, cfg.data.VAL_OPT, cfg.data.V12, cfg.optim.BATCH_SIZE, args.workers)
    len_data = len_val_dataset

    # summarize information for normalization
    norm_dict = get_norm_dict(stanext_data_info, device)

    # prepare complete model (BARC/BITE network used only to look up model meta data here)
    bite_model = BITEInferenceModel(cfg, path_model_file_complete, norm_dict)
    smal_model_type = bite_model.smal_model_type
    smal = SMAL(smal_model_type=smal_model_type, template_name='neutral', logscale_part_list=logscale_part_list).to(device)
    silh_renderer = SilhRenderer(image_size=256).to(device)

    # ----------------------------------------------------------------------------------
    # per-image accumulators for the evaluation metrics
    summary = {
        'pck': np.zeros(len_data),
        'pck_by_part': {group: np.zeros(len_data) for group in KEYPOINT_GROUPS},
        'acc_sil_2d': np.zeros(len_data),
    }

    # progress bar over the validation loader
    iterable = tqdm(enumerate(val_loader), desc='Train', total=len(val_loader), ascii=True, leave=False)

    ind_img_tot = 0
    my_step = 0
    batch_size = cfg.optim.BATCH_SIZE
    for index, (input, target_dict) in iterable:
        # move the relevant target tensors to the evaluation device
        for key in target_dict.keys():
            if key == 'breed_index':
                target_dict[key] = target_dict[key].long().to(device)
            elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
                target_dict[key] = target_dict[key].float().to(device)
            elif key == 'has_seg':
                target_dict[key] = target_dict[key].to(device)
            else:
                pass
        input = input.float().to(device)

        # get starting values for the optimization
        #   -> here from barc, but could also be saved and loaded
        preds_dict = bite_model.get_all_results(input)
        res_normal_and_ref = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['normal', 'ref'])
        res = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['ref'])['ref']

        # --------------------------------------------------------------------
        # load the previously saved ttopt mesh and focal length for every image
        batch_verts_smal = []
        batch_faces_prep = []
        batch_optimed_camera_flength = []
        for ind_img in range(input.shape[0]):
            name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
            print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
            ind_img_tot += 1
            npy_file = ROOT_IN_PATH_DETAIL + name + '_flength_' + e_name + '.npy'
            flength = np.load(npy_file)
            optimed_camera_flength = torch.tensor(flength, device=device)
            obj_file = ROOT_IN_PATH + name + '_res_' + e_name + '.obj'
            verts, faces, aux = p3d.io.load_obj(obj_file)
            batch_verts_smal.append(verts[None, ...].to(device))
            batch_faces_prep.append(faces.verts_idx[None, ...].to(device))
            batch_optimed_camera_flength.append(optimed_camera_flength)
        verts_smal = torch.cat(batch_verts_smal, dim=0)
        faces_prep = torch.cat(batch_faces_prep, dim=0)
        optimed_camera_flength = torch.cat(batch_optimed_camera_flength, dim=0)

        # get keypoint locations from mesh vertices
        keyp_3d = smal.get_joints_from_verts(verts_smal, keyp_conf='olive')

        # render silhouette and keypoints
        pred_silh_images, pred_keyp_raw = silh_renderer(vertices=verts_smal, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength)
        pred_keyp = pred_keyp_raw[:, :24, :]  # first 24 keypoints are the evaluated ones

        # --------------- calculate iou and pck values --------------------
        # keypoints are stored in the 64x64 heatmap frame; rescale to 256x256 image coords
        gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
        gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2)

        # prepare silhouette for IoU calculation - predicted as well as ground truth
        has_seg = target_dict['has_seg']
        img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
        gtseg = target_dict['silh']
        # fully binarize the rendered silhouette (the previous two-pass masking
        # left pixels exactly at 0.5 unchanged)
        synth_silhouettes = (pred_silh_images[:, 0, :, :] > 0.5).float()

        # calculate PCK as well as IoU (similar to WLDO)
        preds = {}
        preds['acc_PCK'] = Metrics.PCK(
            pred_keyp, gt_keypoints,
            gtseg, has_seg, idxs=EVAL_KEYPOINTS,
            thresh_range=[pck_thresh],  # [0.15],
        )
        preds['acc_IOU'] = Metrics.IOU(
            synth_silhouettes, gtseg,
            img_border_mask, mask=has_seg
        )
        for group, group_kps in KEYPOINT_GROUPS.items():
            preds[f'{group}_PCK'] = Metrics.PCK(
                pred_keyp, gt_keypoints, gtseg, has_seg,
                thresh_range=[pck_thresh],  # [0.15],
                idxs=group_kps
            )

        # write the per-image results into the flat summary arrays
        curr_batch_size = pred_keyp.shape[0]
        pck_values = preds['acc_PCK'].data.cpu().numpy()
        target_slice = summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size]
        if pck_values.shape != target_slice.shape:
            # was: import pdb; pdb.set_trace() -- fail loudly instead of hanging batch jobs
            raise RuntimeError(
                'PCK result shape {} does not match summary slice shape {} at step {}'.format(
                    pck_values.shape, target_slice.shape, my_step))
        summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size] = pck_values
        summary['acc_sil_2d'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
        for part in summary['pck_by_part']:
            summary['pck_by_part'][part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
        my_step += 1

    # aggregate over the whole dataset (nanmean: images without annotation are NaN)
    iou = np.nanmean(summary['acc_sil_2d'])
    pck = np.nanmean(summary['pck'])
    pck_legs = np.nanmean(summary['pck_by_part']['legs'])
    pck_tail = np.nanmean(summary['pck_by_part']['tail'])
    pck_ears = np.nanmean(summary['pck_by_part']['ears'])
    pck_face = np.nanmean(summary['pck_by_part']['face'])
    print('------------------------------------------------')
    print("iou: {:.2f}".format(iou * 100))
    print(' ')
    print("pck: {:.2f}".format(pck * 100))
    print(' ')
    print("pck_legs: {:.2f}".format(pck_legs * 100))
    print("pck_tail: {:.2f}".format(pck_tail * 100))
    print("pck_ears: {:.2f}".format(pck_ears * 100))
    print("pck_face: {:.2f}".format(pck_face * 100))
    print('------------------------------------------------')

    # save results in a .txt file
    with open(ROOT_OUT_PATH + "a_evaluation_" + e_name + ".txt", "w") as text_file:
        print("iou: {:.2f}".format(iou * 100), file=text_file)
        print("pck: {:.2f}".format(pck * 100), file=text_file)
        print("pck_legs: {:.2f}".format(pck_legs * 100), file=text_file)
        print("pck_tail: {:.2f}".format(pck_tail * 100), file=text_file)
        print("pck_ears: {:.2f}".format(pck_ears * 100), file=text_file)
        print("pck_face: {:.2f}".format(pck_face * 100), file=text_file)
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser(description='Evaluate a stacked hourglass model.') | |
| parser.add_argument('--model-file-complete', default='', type=str, metavar='PATH', | |
| help='path to saved model weights') | |
| parser.add_argument('--ttopt-result-name', default='', type=str, metavar='PATH', | |
| help='path to saved ttopt results') | |
| parser.add_argument('-cg', '--config', default='barc_cfg_test.yaml', type=str, metavar='PATH', | |
| help='name of config file (default: barc_cfg_test.yaml within src/configs folder)') | |
| parser.add_argument('--save-images', default='True', type=lambda x: bool(strtobool(x)), | |
| help='bool indicating if images should be saved') | |
| parser.add_argument('--workers', default=4, type=int, metavar='N', | |
| help='number of data loading workers') | |
| parser.add_argument('--metrics', '-m', metavar='METRICS', default='all', | |
| choices=['all', None], | |
| help='model architecture') | |
| main(parser.parse_args()) | |