Spaces:
Running
on
Zero
Running
on
Zero
| import cv2, os | |
| import sys | |
| sys.path.insert(0, '..') | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
| import pickle | |
| import importlib | |
| from math import floor | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.utils.data | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import torchvision.datasets as datasets | |
| import torchvision.models as models | |
| from networks import * | |
| import data_utils | |
| from functions import * | |
| from mobilenetv3 import mobilenetv3_large | |
| if not len(sys.argv) == 4: | |
| print('Format:') | |
| print('python lib/test.py config_file test_labels test_images') | |
| exit(0) | |
| experiment_name = sys.argv[1].split('/')[-1][:-3] | |
| data_name = sys.argv[1].split('/')[-2] | |
| test_labels = sys.argv[2] | |
| test_images = sys.argv[3] | |
| config_path = '.experiments.{}.{}'.format(data_name, experiment_name) | |
| my_config = importlib.import_module(config_path, package='PIPNet') | |
| Config = getattr(my_config, 'Config') | |
| cfg = Config() | |
| cfg.experiment_name = experiment_name | |
| cfg.data_name = data_name | |
| os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu_id) | |
| if not os.path.exists(os.path.join('./logs', cfg.data_name)): | |
| os.mkdir(os.path.join('./logs', cfg.data_name)) | |
| log_dir = os.path.join('./logs', cfg.data_name, cfg.experiment_name) | |
| if not os.path.exists(log_dir): | |
| os.mkdir(log_dir) | |
| logging.basicConfig(filename=os.path.join(log_dir, 'train.log'), level=logging.INFO) | |
| save_dir = os.path.join('./snapshots', cfg.data_name, cfg.experiment_name) | |
| if not os.path.exists(save_dir): | |
| os.mkdir(save_dir) | |
| if cfg.det_head == 'pip': | |
| meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(os.path.join('data', cfg.data_name, 'meanface.txt'), cfg.num_nb) | |
| if cfg.det_head == 'pip': | |
| if cfg.backbone == 'resnet18': | |
| resnet18 = models.resnet18(pretrained=cfg.pretrained) | |
| net = Pip_resnet18(resnet18, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride) | |
| elif cfg.backbone == 'resnet50': | |
| resnet50 = models.resnet50(pretrained=cfg.pretrained) | |
| net = Pip_resnet50(resnet50, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride) | |
| elif cfg.backbone == 'resnet101': | |
| resnet101 = models.resnet101(pretrained=cfg.pretrained) | |
| net = Pip_resnet101(resnet101, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride) | |
| elif cfg.backbone == 'mobilenet_v2': | |
| mbnet = models.mobilenet_v2(pretrained=cfg.pretrained) | |
| net = Pip_mbnetv2(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride) | |
| elif cfg.backbone == 'mobilenet_v3': | |
| mbnet = mobilenetv3_large() | |
| if cfg.pretrained: | |
| mbnet.load_state_dict(torch.load('lib/mobilenetv3-large-1cd25616.pth')) | |
| net = Pip_mbnetv3(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride) | |
| else: | |
| print('No such backbone!') | |
| exit(0) | |
| else: | |
| print('No such head:', cfg.det_head) | |
| exit(0) | |
| if cfg.use_gpu: | |
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| else: | |
| device = torch.device("cpu") | |
| net = net.to(device) | |
| weight_file = os.path.join(save_dir, 'epoch%d.pth' % (cfg.num_epochs-1)) | |
| state_dict = torch.load(weight_file) | |
| net.load_state_dict(state_dict) | |
| normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]) | |
| preprocess = transforms.Compose([transforms.Resize((cfg.input_size, cfg.input_size)), transforms.ToTensor(), normalize]) | |
| norm_indices = None | |
| if cfg.data_name == 'data_300W' or cfg.data_name == 'data_300W_COFW_WFLW' or cfg.data_name == 'data_300W_CELEBA': | |
| norm_indices = [36, 45] | |
| elif cfg.data_name == 'COFW': | |
| norm_indices = [8, 9] | |
| elif cfg.data_name == 'WFLW': | |
| norm_indices = [60, 72] | |
| elif cfg.data_name == 'AFLW': | |
| pass | |
| elif cfg.data_name == 'LaPa': | |
| norm_indices = [66, 79] | |
| else: | |
| print('No such data!') | |
| exit(0) | |
| labels = get_label(cfg.data_name, test_labels) | |
| nmes_std = [] | |
| nmes_merge = [] | |
| norm = None | |
| time_all = 0 | |
| for label in labels: | |
| image_name = label[0] | |
| lms_gt = label[1] | |
| if cfg.data_name == 'data_300W' or cfg.data_name == 'COFW' or cfg.data_name == 'WFLW' or cfg.data_name == 'data_300W_COFW_WFLW' or cfg.data_name == 'data_300W_CELEBA': | |
| norm = np.linalg.norm(lms_gt.reshape(-1, 2)[norm_indices[0]] - lms_gt.reshape(-1, 2)[norm_indices[1]]) | |
| elif cfg.data_name == 'AFLW': | |
| norm = 1 | |
| image_path = os.path.join('data', cfg.data_name, test_images, image_name) | |
| image = cv2.imread(image_path) | |
| image = cv2.resize(image, (cfg.input_size, cfg.input_size)) | |
| inputs = Image.fromarray(image[:,:,::-1].astype('uint8'), 'RGB') | |
| inputs = preprocess(inputs).unsqueeze(0) | |
| inputs = inputs.to(device) | |
| t1 = time.time() | |
| if cfg.det_head == 'pip': | |
| lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_pip(net, inputs, preprocess, cfg.input_size, cfg.net_stride, cfg.num_nb) | |
| else: | |
| print('No such head!') | |
| if cfg.det_head == 'pip': | |
| # merge neighbor predictions | |
| lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten() | |
| tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(cfg.num_lms, max_len) | |
| tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(cfg.num_lms, max_len) | |
| tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1,1) | |
| tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1,1) | |
| lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten() | |
| t2 = time.time() | |
| time_all += (t2-t1) | |
| if cfg.det_head == 'pip': | |
| lms_pred = lms_pred.cpu().numpy() | |
| lms_pred_merge = lms_pred_merge.cpu().numpy() | |
| nme_std = compute_nme(lms_pred, lms_gt, norm) | |
| nmes_std.append(nme_std) | |
| if cfg.det_head == 'pip': | |
| nme_merge = compute_nme(lms_pred_merge, lms_gt, norm) | |
| nmes_merge.append(nme_merge) | |
| print('Total inference time:', time_all) | |
| print('Image num:', len(labels)) | |
| print('Average inference time:', time_all/len(labels)) | |
| if cfg.det_head == 'pip': | |
| print('nme: {}'.format(np.mean(nmes_merge))) | |
| logging.info('nme: {}'.format(np.mean(nmes_merge))) | |
| fr, auc = compute_fr_and_auc(nmes_merge) | |
| print('fr : {}'.format(fr)) | |
| logging.info('fr : {}'.format(fr)) | |
| print('auc: {}'.format(auc)) | |
| logging.info('auc: {}'.format(auc)) | |