| |
|
|
| import os |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap |
| from prettytable import PrettyTable |
| from sklearn.metrics import roc_curve, auc |
|
|
# Root of the IJB-C release; the pair protocol is read from "<root>/meta".
image_path = "/data/anxiang/IJB_release/IJBC"

# Per-pair score files to evaluate; each one's parent directory names the method.
files = ["./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"]
|
|
|
|
def read_template_pair_list(path):
    """Load a template-pair list from a space-separated text file.

    The file has no header row. Column 0 is the first template id, column 1
    the second template id and column 2 the pair label.

    Args:
        path: Path to the pair file.

    Returns:
        Tuple ``(t1, t2, label)`` of integer NumPy arrays, one entry per row.
    """
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # ``np.int`` was only a deprecated alias for the builtin ``int`` and was
    # removed in NumPy 1.24; the builtin gives the same default integer dtype.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
|
|
|
|
# Load the IJB-C verification protocol: the two template ids of every pair and
# the pair labels that serve as ground truth for roc_curve further down.
# Join path components directly rather than pre-formatting "<root>/meta".
p1, p2, label = read_template_pair_list(
    os.path.join(image_path, 'meta', 'ijbc_template_pair_label.txt'))
|
|
# Name each score file after its parent directory, e.g. "./a/b/ijbc.npy" -> "b".
methods = np.array([path.split('/')[-2] for path in files])
# Method name -> loaded score array; if two files share a name the later wins.
scores = dict(zip(methods, [np.load(path) for path in files]))
# Assign each method a colour sampled from the 'Set2' colourmap.
colours = dict(zip(methods, sample_colours_from_colourmap(len(methods), 'Set2')))
# FPR operating points 1e-6 ... 1e-1 at which TPR is tabulated.
x_labels = [10 ** exponent for exponent in range(-6, 0)]
tpr_fpr_table = PrettyTable(['Methods', *map(str, x_labels)])
# Plot one ROC curve per method and tabulate TPR at each FPR in ``x_labels``.
# NOTE(review): ``fig`` is never saved or shown in this view; presumably that
# happens later in the file or the script runs in a notebook — confirm.
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = ["%s-%s" % (method, "IJBC")]
    for target_fpr in x_labels:
        # Use the ROC point whose FPR is closest to the target. np.argmin
        # returns the first index on ties, matching the previous
        # min-over-(distance, index) search, but runs vectorized instead of
        # building a Python list of tuples over every ROC point.
        min_index = np.argmin(np.abs(fpr - target_fpr))
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
print(tpr_fpr_table)
|
|